diff --git a/src/check_fingerprint.py b/src/check_fingerprint.py new file mode 100644 index 0000000..3625d18 --- /dev/null +++ b/src/check_fingerprint.py @@ -0,0 +1,51 @@ +import enum + +import entry +import hashing + +class result(enum.Enum): + notfound, ok, fail = range(3) + +def check_fingerprint(entries, domain, port, fingerprint): + """check_fingerprint([Entry], str, u16, bytes[32]) → (enum result: result, str / None: comment) + Checks if the given host is found with the given fingerprint. + Will return the comment on the host if the host is found, regardless + of whether the fingerprint checks out.""" + assert type(entries) == list and all(type(i) == entry.Entry for i in entries) + assert type(domain) == str + assert type(port) == int and 0 <= port <= (1<<16) - 1 + assert type(fingerprint) == bytes and len(fingerprint) == 32 + + # Normalize the host here, so we don't have to do it every time we + # check for a possible match + normalized_hosts = [entry.normalize_host(domain, port)] + + # If we are looking at non-22 port, also check the general form of + # the host without a port specifier. This seems to be how OpenSSH + # does it too + if port != 22: + normalized_hosts.append(entry.normalize_host(domain, 22)) + + for possible_match in entries: + for normalized_host in normalized_hosts: + hashed_host = hashing.hash_with_salt(normalized_host, possible_match.salt) + if hashed_host == possible_match.hashed_host: + # Convert the comment to a string + # We put replacement characters where + # decoding fails instead of throwing an + # error, because even whilethe comment + # field must be valid utf-8, failing in this + # situation is bad UX + comment = possible_match.comment.decode('utf-8', errors = 'replace') + # TODO: Justify this + # We only care about the first match, so we + # return here + if fingerprint == possible_match.fingerprint: + # Fingerprint matches, it passes + return (result.ok, comment) + else: + # Fingerprint different, it fails + return (result.fail, comment) + + # We did not match, tell the caller so + return (result.notfound, None) diff --git a/src/process_known_hosts.py b/src/process_known_hosts.py index ed6cb6f..886d0c1 100644 --- a/src/process_known_hosts.py +++ b/src/process_known_hosts.py @@ -13,6 +13,8 @@ def process_line(line): a list of Entries based on it.""" assert type(line) == str + # TODO: Skip over IPs somehow? + # Remove trailing newlines if line[-1] == '\n': line = line[:-1]