#!/usr/bin/env python3
"""Toy authenticated encryption built only from SHA-256.

Ciphertext layout (sizes in bytes)::

    hkdf_salt (32) || cipher_nonce (32) || ciphered (len(plaintext)) || tag (32)

The 256-bit master key is expanded with HKDF-SHA256 (RFC 5869) under a
random salt into an HMAC key and a cipher key.  The plaintext is XORed with
an HMAC-SHA256 counter-mode keystream, and the tag is HMAC-SHA256 over every
byte that precedes it (encrypt-then-MAC).

Note: Python ``bytes`` objects are immutable and cannot be reliably wiped,
so key material is not scrubbed from memory after use.

Usage: <script> enc|dec HEXKEY < input > output
"""
import hashlib
import secrets
import sys

sha256_blocksize = hashlib.sha256().block_size
sha256_outputsize = hashlib.sha256().digest_size

# Widths of the fixed-size fields, in bytes.
_KEY_SIZE = 256 // 8
_SALT_SIZE = 256 // 8
_NONCE_SIZE = 256 // 8


def xor(x, y):
    """Yield the bytewise XOR of two equal-length byte sequences."""
    assert len(x) == len(y)
    for a, b in zip(x, y):
        yield a ^ b


def hmac_sha256(key, message):
    """Return HMAC-SHA256(key, message) as 32 bytes (RFC 2104).

    Hand-rolled on top of hashlib; equivalent to
    ``hmac.new(key, message, hashlib.sha256).digest()``.
    """
    # Keys longer than the block size are replaced by their digest.
    if len(key) > sha256_blocksize:
        key = hashlib.sha256(key).digest()
    # Deliberately not an elif: a hashed key (32 bytes) is still shorter
    # than the block size (64 bytes) and must be zero-padded as well.
    if len(key) < sha256_blocksize:
        key = key + b'\x00' * (sha256_blocksize - len(key))

    ipad = b'\x36' * sha256_blocksize
    opad = b'\x5c' * sha256_blocksize

    # Inner hash: H((K ^ ipad) || message)
    m = hashlib.sha256()
    m.update(bytes(xor(key, ipad)))
    m.update(message)
    inner = m.digest()

    # Outer hash: H((K ^ opad) || inner)
    m = hashlib.sha256()
    m.update(bytes(xor(key, opad)))
    m.update(inner)
    return m.digest()


def ceildiv(p, q):
    """Return ceil(p / q) for p >= 0 and q > 0, using integer arithmetic only."""
    assert p >= 0
    assert q > 0
    quotient, remainder = divmod(p, q)
    return quotient + 1 if remainder else quotient


def hkdf_sha256(salt, key_material, info, length):
    """Derive `length` bytes of key material with HKDF-SHA256 (RFC 5869).

    The block index is encoded in a single byte, so at most
    255 * 32 = 8160 bytes can be produced.
    """
    assert 0 <= length <= 255 * sha256_outputsize

    # Extract: an absent (empty) salt is replaced by HashLen zero bytes.
    if salt == b'':
        salt = b'\x00' * sha256_outputsize
    pseudorandom_key = hmac_sha256(salt, key_material)

    # Expand: T(0) = b'' and T(i) = HMAC(PRK, T(i-1) || info || i), i = 1..N.
    blocks = []
    previous = b''
    for index in range(1, ceildiv(length, sha256_outputsize) + 1):
        previous = hmac_sha256(pseudorandom_key,
                               previous + info + bytes([index]))
        blocks.append(previous)

    # Truncate to the requested size.
    return b''.join(blocks)[:length]


def hmac_sha256_ctr_keystream(nonce, key):
    """Yield an endless keystream one byte (int) at a time.

    Block i is HMAC-SHA256(key, nonce || i), where `nonce` is 32 bytes and
    i is a 256-bit big-endian counter starting at 0, so each HMAC input is
    one 512-bit SHA-256 block.
    """
    assert len(nonce) == _NONCE_SIZE
    assert len(key) == _KEY_SIZE
    counter = 0
    while True:
        yield from hmac_sha256(key, nonce + counter.to_bytes(256 // 8, 'big'))
        counter += 1


def _derive_keys(hkdf_salt, key):
    """Expand `key` under `hkdf_salt` and return (hmac_key, cipher_key).

    The HMAC key is the first half of the HKDF output, so reaching the
    cipher key requires computing the full expansion rather than only its
    first block.  This is a precaution; no specific attack motivates it.
    """
    keys = hkdf_sha256(hkdf_salt, key, b'', 2 * _KEY_SIZE)
    return keys[:_KEY_SIZE], keys[_KEY_SIZE:]


def _apply_keystream(data, cipher_nonce, cipher_key):
    """XOR `data` with the keystream; the same operation encrypts and decrypts."""
    keystream = hmac_sha256_ctr_keystream(cipher_nonce, cipher_key)
    return bytearray(d ^ k for d, k in zip(data, keystream))


def shacrypt_enc(key, plaintext):
    """Encrypt and authenticate `plaintext` under the 32-byte `key`.

    Returns bytes laid out as hkdf_salt || cipher_nonce || ciphered || tag.
    A fresh random salt and nonce are drawn from `secrets` on every call.
    """
    assert len(key) == _KEY_SIZE

    hkdf_salt = secrets.token_bytes(_SALT_SIZE)
    cipher_nonce = secrets.token_bytes(_NONCE_SIZE)
    hmac_key, cipher_key = _derive_keys(hkdf_salt, key)

    ciphered = _apply_keystream(plaintext, cipher_nonce, cipher_key)

    # The tag covers the salt and nonce as well as the ciphered bytes.
    hmaced = b''.join((hkdf_salt, cipher_nonce, ciphered))
    return hmaced + hmac_sha256(hmac_key, hmaced)


class AuthenticationError(Exception):
    """Raised when a ciphertext is malformed or its HMAC tag does not verify."""


def shacrypt_dec(key, ciphertext):
    """Verify and decrypt a ciphertext produced by shacrypt_enc.

    Returns the plaintext as a bytearray.

    Raises:
        AuthenticationError: if `ciphertext` is too short to hold the salt,
            nonce and tag, or if the tag does not match.
    """
    assert len(key) == _KEY_SIZE

    header_size = _SALT_SIZE + _NONCE_SIZE
    if len(ciphertext) < header_size + sha256_outputsize:
        raise AuthenticationError

    hmaced = ciphertext[:-sha256_outputsize]
    expected_hmac = ciphertext[-sha256_outputsize:]
    hkdf_salt = hmaced[:_SALT_SIZE]
    cipher_nonce = hmaced[_SALT_SIZE:header_size]

    hmac_key, cipher_key = _derive_keys(hkdf_salt, key)

    # Verify before touching the ciphered bytes; compare in constant time.
    if not secrets.compare_digest(expected_hmac,
                                  hmac_sha256(hmac_key, hmaced)):
        raise AuthenticationError

    return _apply_keystream(hmaced[header_size:], cipher_nonce, cipher_key)


def main():
    """CLI entry point: `enc|dec KEY`; reads stdin, writes the result to stdout."""
    if len(sys.argv) != 3:
        print('Usage: %s enc|dec key' % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    try:
        key = bytes.fromhex(sys.argv[2])
    except ValueError:
        print('%s: Error: Key must be hex-encoded' % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    if len(key) != _KEY_SIZE:
        print('%s: Error: Key must be 256 bits long' % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    if sys.argv[1] == 'enc':
        plaintext = sys.stdin.buffer.read()
        ciphertext = shacrypt_enc(key, plaintext)
        sys.stdout.buffer.write(ciphertext)

    elif sys.argv[1] == 'dec':
        ciphertext = sys.stdin.buffer.read()
        try:
            plaintext = shacrypt_dec(key, ciphertext)
        except AuthenticationError:
            print('%s: Error: HMAC mismatch' % sys.argv[0], file=sys.stderr)
            sys.exit(1)
        sys.stdout.buffer.write(plaintext)

    else:
        print('Usage: %s enc|dec key' % sys.argv[0], file=sys.stderr)
        sys.exit(1)

    sys.stdout.buffer.flush()


if __name__ == '__main__':
    main()