diff --git a/src/entry.py b/src/entry.py index 9305ac5..81fcec2 100644 --- a/src/entry.py +++ b/src/entry.py @@ -7,6 +7,27 @@ Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment']) class UnacceptableComment(Exception): pass +def normalize_host(domain, port): + """normalize_host(str, u16) → bytes + Transform a domain and port into the format in which it will be hashed: the IDNA-encoded, lowercased domain, written as [host]:port when the port is not 22""" + assert type(domain) == str + assert type(port) == int and 0 <= port <= (1<<16) - 1 + + # We want to have domain names reasonably normalized. This is why we + # convert all internationalized domain names to punycode and + # lowercase all domains. + # The reason the lowercasing happens after the punycoding is because + # that way we don't have to worry about Unicode case mapping: in + # case of IDN the IDNA codec handles that for us, and in case of an + # ASCII domain it passes through the IDNA unmodified + normalized_host = domain.encode('idna').lower() + + # If the port is not :22, we store [host]:port instead + if port != 22: + normalized_host = b'[%s]:%i' % (normalized_host, port) + + return normalized_host + def create_entry(domain, port, fingerprint, comment): """create_entry(str, u16, bytes[32], str) → Entry Given unprocessed host, a binary fingerprint and a comment, creates @@ -16,21 +37,11 @@ def create_entry(domain, port, fingerprint, comment): assert type(fingerprint) == bytes and len(fingerprint) == 32 assert type(comment) == str - # We want to have domain names reasonably normalized. This is why we - # convert all internationalized domain names to punycode and - # lowercase all domains. 
- # The reason the lowercasing happens after the punycoding is because - # that way we don't have to worry about Unicode case mapping: in - # case of IDN the IDNA codec handles that for us, and in case of an - # ASCII domain it passes through the IDNA unmodified - processed_host = domain.encode('idna').lower() - - # If the port is not :22, we store [host]:port instead - if port != 22: - processed_host = b'[%s]%i' % (processed_host, port) + # Normalize the host before hashing + normalized_host = normalize_host(domain, port) # Hash the host and store the salt - salt, hashed_host = hashing.hash_host(processed_host) + salt, hashed_host = hashing.hash_host(normalized_host) # Comment must not include newlines if '\n' in comment: