Revamp filtering entries

This commit is contained in:
Juhani Krekelä 2018-08-31 20:21:08 +03:00
parent 4a3107a001
commit 004ec12ec3
2 changed files with 62 additions and 12 deletions

View File

@ -1,16 +1,23 @@
import enum
from collections import namedtuple
import entry
import hashing
# TODO: Include a thing for checking what hosts match a given fingerprint
# Result(str/None, u16/None, str)
Result = namedtuple('Result', ['domain', 'port', 'comment'])
def check_fingerprint(entries, domain, port, fingerprint):
"""check_fingerprint([Entry], str, u16, bytes[32]) → ([str]: successes, [str]: fails)
def check(entries, domain, port, fingerprint):
"""check([Entry], str, u16, bytes[32]) → ([Result]: successes, [Result]: fails, [Result]: same_fingerprint)
Checks if the given host is found with the given fingerprint.
The successes and fails lists returned by the function have the
comments for the hosts that match and have the same fingerpring and
the hosts that match but have a different fingerprint, respectively"""
successes contains ones where both the host and the fingerprint match.
fails contains ones where host matches but the fingerprint doesn't.
same_fingerprint contains ones where fingerprint matches but the
host doesn't. Their .domain and .port will be None"""
assert type(entries) == list and all(type(i) == entry.Entry for i in entries)
assert type(domain) == str
assert type(port) == int and 0 <= port <= (1<<16) - 1
@ -18,25 +25,36 @@ def check_fingerprint(entries, domain, port, fingerprint):
# Normalize the host here, so we don't have to do it every time we
# check for a possible match
normalized_hosts = [entry.normalize_host(domain, port)]
normalized_hosts = {port: entry.normalize_host(domain, port)}
# If we are looking at non-22 port, also check the general form of
# the host without a port specifier. This seems to be how OpenSSH
# does it too
if port != 22:
normalized_hosts.append(entry.normalize_host(domain, 22))
normalized_hosts[22] = entry.normalize_host(domain, 22)
successes = []
fails = []
same_fingerprint = []
for possible_match in entries:
for normalized_host in normalized_hosts:
any_host_matched = False
for current_port, normalized_host in normalized_hosts.items():
hashed_host = hashing.hash_with_salt(normalized_host, possible_match.salt)
if hashed_host == possible_match.hashed_host:
if fingerprint == possible_match.fingerprint:
# Fingerprint matches, it passes
successes.append(possible_match.comment)
successes.append(Result(domain, current_port, possible_match.comment))
any_host_matched = True
else:
# Fingerprint different, it fails
fails.append(possible_match.comment)
fails.append(Result(domain, current_port, possible_match.comment))
return successes, fails
if not any_host_matched and fingerprint == possible_match.fingerprint:
# Host is not the same, but the fingerprint
# matches
print(possible_match)#debg
same_fingerprint.append(Result(None, None, possible_match.comment))
return successes, fails, same_fingerprint

View File

@ -48,3 +48,35 @@ def create_entry(domain, port, fingerprint, comment):
raise UnacceptableComment('Comment contains newlines')
return Entry(salt, hashed_host, fingerprint, comment)
def filter_by_host(entries, domain, port):
"""filter_by_host([Entry], str, u16) → [Entry]
Return hosts that match given domain and port."""
assert type(entries) == list and all(type(i) == Entry for i in entries)
assert type(domain) == str
assert type(port) == int and 0 <= port <= (1<<16) - 1
# Normalize the host here, so we don't have to do it every time we
# check for a match
normalized_host = normalize_host(domain, port)
entries = []
for entry in entries:
hashed_host = hashing.hash_with_salt(normalized_host, entry.salt)
if hashed_host == entry.hashed_host:
entries.append(entry)
return entries
def filter_by_fingerprint(entries, fingerprint):
"""filter_by_fingerprint([Entry], bytes[32]) → [Entry]
Return hosts that match given fingerprint."""
assert type(entries) == list and all(type(i) == Entry for i in entries)
assert type(fingerprint) == bytes and len(fingerprint) == 32
entries = []
for entry in entries:
if fingerprint == entry.fingerprint:
entries.append(entry)
return entries