Split the host normalization into its own function and actually handle non-22 ports correctly

This commit is contained in:
Juhani Krekelä 2018-08-28 20:14:21 +03:00
parent 433f9baf31
commit 320445ab25
1 changed files with 24 additions and 13 deletions

View File

@ -7,6 +7,27 @@ Entry = namedtuple('Entry', ['salt', 'hashed_host', 'fingerprint', 'comment'])
class UnacceptableComment(Exception): pass
def normalize_host(domain, port):
    """normalize_host(str, u16) → bytes
    Transform a domain and port into the byte string that gets hashed.

    The domain is IDNA-encoded and lowercased. For port 22 that is the
    whole result; for any other port the result is b'[host]:port'."""
    assert type(domain) == str
    assert type(port) == int and 0 <= port <= 0xFFFF

    # Punycode first, lowercase second. Doing it in this order means we
    # never have to deal with Unicode case mapping ourselves: for an
    # internationalized name the IDNA codec takes care of it, and a
    # plain ASCII name goes through the codec unchanged, so lowercasing
    # the resulting bytes is enough.
    host_bytes = domain.encode('idna').lower()

    # Port 22 is stored as the bare host
    if port == 22:
        return host_bytes

    # Any other port is stored in the bracketed [host]:port form
    return b'[%s]:%i' % (host_bytes, port)
def create_entry(domain, port, fingerprint, comment):
"""create_entry(str, u16, bytes[32], str) → Entry
Given unprocessed host, a binary fingerprint and a comment, creates
@ -16,21 +37,11 @@ def create_entry(domain, port, fingerprint, comment):
assert type(fingerprint) == bytes and len(fingerprint) == 32
assert type(comment) == str
# We want to have domain names reasonably normalized. This is why we
# convert all internationalized domain names to punycode and
# lowercase all domains.
# The reason the lowercasing happens after the punycoding is because
# that way we don't have to worry about Unicode case mapping: in
# case of IDN the IDNA codec handles that for us, and in case of an
# ASCII domain it passes through the IDNA unmodified
processed_host = domain.encode('idna').lower()
# If the port is not :22, we store [host]:port instead
if port != 22:
processed_host = b'[%s]%i' % (processed_host, port)
# Normalize the host before hashing
normalized_host = normalize_host(domain, port)
# Hash the host and store the salt
salt, hashed_host = hashing.hash_host(processed_host)
salt, hashed_host = hashing.hash_host(normalized_host)
# Comment must not include newlines
if '\n' in comment: