diff --git a/shacrypt.py b/shacrypt.py index 2e26ddf..67f91de 100644 --- a/shacrypt.py +++ b/shacrypt.py @@ -117,35 +117,38 @@ def shacrypt_enc(key, plaintext): del plaintext del cipher_key + # Contruct the HMACed part of ciphertext + hmaced = b''.join(( + hkdf_salt, + cipher_nonce, + ciphered + )) + del ciphered + # HMAC - hmac = hmac_sha256(hmac_key, ciphered) + hmac = hmac_sha256(hmac_key, hmaced) del hmac_key # Construct the full ciphertext - return b''.join(( - hkdf_salt, - cipher_nonce, - ciphered, - hmac - )) + return hmaced + hmac class AuthenticationError(Exception): pass def shacrypt_dec(key, ciphertext): assert len(key) == 256//8 - # Extract the IVs - hkdf_salt = ciphertext[0:256//8] - cipher_nonce = ciphertext[256//8:256//8 + 256//8] - - # Extract the main part of ciphertext - ciphered = ciphertext[2 * 256//8:-sha256_outputsize] + # Extract the HMACed part of ciphertext + hmaced = ciphertext[:-sha256_outputsize] # Extract the expected HMAC expected_hmac = ciphertext[-sha256_outputsize:] del ciphertext + # Extract the IVs + hkdf_salt = hmaced[0:256//8] + cipher_nonce = hmaced[256//8:256//8 + 256//8] + # Derive keys keys = hkdf_sha256(hkdf_salt, key, b'', 512//8) del key @@ -160,13 +163,17 @@ def shacrypt_dec(key, ciphertext): del keys # Verify HMAC - hmac = hmac_sha256(hmac_key, ciphered) + hmac = hmac_sha256(hmac_key, hmaced) del hmac_key if not secrets.compare_digest(expected_hmac, hmac): raise AuthenticationError del expected_hmac del hmac + # Extract the ciphered part of the ciphertext + ciphered = hmaced[2 * 256//8:] + del hmaced + # Decrypt plaintext = bytearray() for cipheredbyte, keybyte in zip(ciphered, hmac_sha256_ctr_keystream(cipher_nonce, cipher_key)):