shacrypt/shacrypt.py

224 lines
5.6 KiB
Python

#!/usr/bin/env python3
import hashlib
import secrets
import sys
# SHA-256 parameters, read from hashlib rather than hard-coded:
# block_size (64 bytes) is the width HMAC pads/hashes keys to, and
# digest_size (32 bytes) is the length of one HMAC/HKDF output block.
sha256_blocksize, sha256_outputsize = (
    hashlib.sha256().block_size,
    hashlib.sha256().digest_size,
)
def xor(x, y):
    """Lazily yield the bytewise XOR of two equal-length byte sequences.

    This is a generator, so the length check only runs once iteration
    starts, not when xor() is called.
    """
    assert len(x) == len(y)
    yield from (left ^ right for left, right in zip(x, y))
def hmac_sha256(key, message):
    """Return the 32-byte HMAC-SHA256 of `message` under `key` (RFC 2104).

    Keys longer than the SHA-256 block size are first replaced by their
    SHA-256 digest; keys shorter than the block size (including such a
    digest) are right-padded with zero bytes up to the block size.
    """
    block_size = hashlib.sha256().block_size
    # Handle long keys: hash them down to the digest size
    if len(key) > block_size:
        key = hashlib.sha256(key).digest()
    # Handle short keys: zero-pad to exactly one block. Done unconditionally
    # because a hashed long key is shorter than a block and needs it too.
    key = bytes(key).ljust(block_size, b'\x00')
    inner_key = bytes(b ^ 0x36 for b in key)  # K XOR ipad
    outer_key = bytes(b ^ 0x5c for b in key)  # K XOR opad
    # Inner hash: H((K ^ ipad) || message)
    inner_hash = hashlib.sha256()
    inner_hash.update(inner_key)
    inner_hash.update(message)
    # Outer hash: H((K ^ opad) || inner digest)
    outer_hash = hashlib.sha256()
    outer_hash.update(outer_key)
    outer_hash.update(inner_hash.digest())
    return outer_hash.digest()
def ceildiv(p, q):
    """Return ceil(p / q) for non-negative p and positive q, in integer math."""
    assert p >= 0
    assert q > 0
    quotient, remainder = divmod(p, q)
    # Any leftover means one more (partial) unit is needed
    return quotient + 1 if remainder else quotient
def hkdf_sha256(salt, key_material, info, length):
    """Derive `length` bytes from `key_material` with HKDF-SHA256 (RFC 5869).

    salt: non-secret salt; an empty (or None) salt is replaced by
        sha256_outputsize zero bytes, as RFC 5869 specifies.
    key_material: input keying material (IKM).
    info: context-specific information mixed into every output block.
    length: number of output bytes; RFC 5869 allows at most 255 blocks,
        i.e. 0 <= length <= 255 * sha256_outputsize.

    Raises ValueError if `length` is outside that range.
    """
    # The RFC limit is on the number of expand blocks (the counter is one
    # byte), not on the number of output bytes.
    max_length = 255 * sha256_outputsize
    if not 0 <= length <= max_length:
        raise ValueError('HKDF-SHA256 output length must be between 0 and %d bytes' % max_length)
    # Extract
    if not salt:
        salt = b'\x00' * sha256_outputsize
    pseudorandom_key = hmac_sha256(salt, key_material)
    # Expand: T(0) = b'', T(i) = HMAC(PRK, T(i-1) | info | i) for i in 1..N
    blocks = []
    previous = b''
    for index in range(1, ceildiv(length, sha256_outputsize) + 1):
        previous = hmac_sha256(pseudorandom_key, previous + info + bytes([index]))
        blocks.append(previous)
    # Cut the output into the size requested
    return b''.join(blocks)[:length]
def hmac_sha256_ctr_keystream(nonce, key):
    """Yield an endless keystream, one byte (int) at a time.

    Block i of the stream is HMAC-SHA256(key, nonce | i), where i is encoded
    as a 32-byte big-endian integer. Each HMAC input is therefore a 512-bit
    message (256-bit nonce + 256-bit counter). Both `nonce` and `key` must
    be 32 bytes; this is checked when iteration starts.
    """
    assert len(nonce) == 256//8
    assert len(key) == 256//8
    block_index = 0
    while True:
        counter_bytes = block_index.to_bytes(256//8, 'big')
        for keystream_byte in hmac_sha256(key, nonce + counter_bytes):
            yield keystream_byte
        block_index += 1
def shacrypt_enc(key, plaintext):
    """Encrypt and authenticate `plaintext` with a 32-byte (256-bit) key.

    Generates a random HKDF salt and a random keystream nonce. HKDF-SHA256
    derives a 32-byte HMAC key and a 32-byte encryption key from (salt, key).
    The plaintext is XORed with the HMAC-SHA256 counter-mode keystream, and
    an HMAC-SHA256 tag is appended over everything before it.

    Returns bytes laid out as:
        hkdf_salt (32) | cipher_nonce (32) | ciphered (len(plaintext)) | tag (32)
    """
    assert len(key) == 256//8
    # Fresh per-message randomness: salt for HKDF, nonce for the keystream
    hkdf_salt = secrets.token_bytes(256//8)
    cipher_nonce = secrets.token_bytes(256//8)
    # Take 64 bytes of HKDF output and split them into two 32-byte subkeys.
    # The HMAC key comes first and the encryption key second, so getting at
    # the encryption key requires the full HKDF expansion rather than only
    # its first block. Unclear whether that helps against any attack, but
    # it costs nothing.
    derived = hkdf_sha256(hkdf_salt, key, b'', 512//8)
    del key
    hmac_key, cipher_key = derived[:256//8], derived[256//8:]
    del derived
    # Encrypt: XOR with the keystream; zip stops at the end of the plaintext
    keystream = hmac_sha256_ctr_keystream(cipher_nonce, cipher_key)
    ciphered = bytearray(p ^ k for p, k in zip(plaintext, keystream))
    del plaintext, cipher_key, keystream
    # Encrypt-then-MAC: authenticate salt, nonce and ciphered data together
    authenticated = b''.join((hkdf_salt, cipher_nonce, ciphered))
    del ciphered
    tag = hmac_sha256(hmac_key, authenticated)
    del hmac_key
    return authenticated + tag
class AuthenticationError(Exception):
    """Raised by shacrypt_dec when a ciphertext fails authentication."""
def shacrypt_dec(key, ciphertext):
    """Verify and decrypt a ciphertext produced by shacrypt_enc.

    Expected layout:
        hkdf_salt (32) | cipher_nonce (32) | ciphered data | HMAC-SHA256 tag (32)
    The tag is computed over everything that precedes it.

    Returns the plaintext as a bytearray. Raises AuthenticationError if the
    input is too short to contain salt, nonce and tag, or if the tag does
    not verify. No decryption is attempted in either case.
    """
    assert len(key) == 256//8
    iv_size = 256//8
    # Reject input that cannot even hold salt + nonce + tag up front, instead
    # of slicing it into short/empty IVs and relying on the HMAC check failing.
    if len(ciphertext) < 2 * iv_size + sha256_outputsize:
        raise AuthenticationError
    # Split the authenticated part from the trailing tag
    hmaced = ciphertext[:-sha256_outputsize]
    expected_hmac = ciphertext[-sha256_outputsize:]
    del ciphertext
    # Extract the IVs
    hkdf_salt = hmaced[:iv_size]
    cipher_nonce = hmaced[iv_size:2 * iv_size]
    # Derive keys; this split (HMAC key first, encryption key second) must
    # match the one in shacrypt_enc.
    keys = hkdf_sha256(hkdf_salt, key, b'', 512//8)
    del key
    hmac_key = keys[:256//8]
    cipher_key = keys[256//8:]
    del keys
    # Verify the HMAC (constant-time comparison) before decrypting anything
    hmac = hmac_sha256(hmac_key, hmaced)
    del hmac_key
    if not secrets.compare_digest(expected_hmac, hmac):
        raise AuthenticationError
    del expected_hmac, hmac
    # Decrypt: XOR the ciphered part with the keystream
    ciphered = hmaced[2 * iv_size:]
    del hmaced
    keystream = hmac_sha256_ctr_keystream(cipher_nonce, cipher_key)
    plaintext = bytearray(c ^ k for c, k in zip(ciphered, keystream))
    del ciphered, cipher_nonce, cipher_key, keystream
    return plaintext
def main():
    """Command-line entry point: ``shacrypt.py enc|dec key``.

    Reads all of stdin. With 'enc' it encrypts the input; with 'dec' it
    verifies and decrypts it. The key is given in hex and must decode to
    256 bits. The result is written to stdout. On wrong usage, a malformed
    key, or (for 'dec') an HMAC mismatch, an error goes to stderr and the
    process exits with status 1.
    """
    usage = 'Usage: %s enc|dec key' % sys.argv[0]
    if len(sys.argv) != 3:
        print(usage, file=sys.stderr)
        sys.exit(1)
    mode = sys.argv[1]
    try:
        key = bytes.fromhex(sys.argv[2])
    except ValueError:
        print('%s: Error: Key must be hex-encoded' % sys.argv[0], file=sys.stderr)
        sys.exit(1)
    if len(key) != 256//8:
        print('%s: Error: Key must be 256 bits long' % sys.argv[0], file=sys.stderr)
        sys.exit(1)
    if mode == 'enc':
        plaintext = sys.stdin.buffer.read()
        ciphertext = shacrypt_enc(key, plaintext)
        sys.stdout.buffer.write(ciphertext)
    elif mode == 'dec':
        ciphertext = sys.stdin.buffer.read()
        try:
            plaintext = shacrypt_dec(key, ciphertext)
        except AuthenticationError:
            print('%s: Error: HMAC mismatch' % sys.argv[0], file=sys.stderr)
            sys.exit(1)
        sys.stdout.buffer.write(plaintext)
    else:
        # Unknown mode: same usage message as a wrong argument count
        print(usage, file=sys.stderr)
        sys.exit(1)
    sys.stdout.buffer.flush()


if __name__ == '__main__':
    main()