diff --git a/src/main-export-known-hosts.py b/src/main-export-known-hosts.py index 15501a9..7253d73 100644 --- a/src/main-export-known-hosts.py +++ b/src/main-export-known-hosts.py @@ -4,12 +4,14 @@ import process_known_hosts import write_file def main(): - entries = [] # TODO: Don't hardcode # TODO: Handle errors with open(sys.argv[1], 'r') as f: - for line in f: - entries.extend(process_known_hosts.process_line(line)) + try: + entries = process_known_hosts.process_file(f) + except Exception as err: + print('Error: %s' % err, file=sys.stderr) + sys.exit(1) with open('known_hosts.sshwot', 'wb') as f: write_file.write(f, entries) diff --git a/src/process_known_hosts.py b/src/process_known_hosts.py index 55be930..e9c9687 100644 --- a/src/process_known_hosts.py +++ b/src/process_known_hosts.py @@ -3,10 +3,27 @@ import hashlib import entry -# TODO: Include line number in the error -class KnownHostsSyntaxError(Exception): pass +class KnownHostsSyntaxError(Exception): + def __init__(self, string): + self.string = string + self.line = None -class HashedHostError(Exception): pass + def __str__(self): + if self.line == None: + return self.string + else: + return 'Line %i: %s' % (self.line, self.string) + +class HashedHostError(Exception): + def __init__(self, string): + self.string = string + self.line = None + + def __str__(self): + if self.line == None: + return self.string + else: + return 'Line %i: %s' % (self.line, self.string) def process_line(line): # TODO: Add a way to skip IPs @@ -15,8 +32,6 @@ def process_line(line): a list of Entries based on it.""" assert type(line) == str - # TODO: Skip over IPs somehow? - # Remove trailing newlines if line[-1] == '\n': line = line[:-1] @@ -78,3 +93,19 @@ def process_line(line): entries.append(entry.create_entry(domain, port, fingerprint, '')) return entries + +def process_file(f): + """process_file(file(r)) → [Entry] + Given a file in the .ssh/known_hosts format, create a list of + entries""" + + entries = [] + # Line numbers are 1-indexed but enumerate 0-indexes + for linenum_minus_one, line in enumerate(f): + try: + entries.extend(process_line(line)) + except (KnownHostsSyntaxError, HashedHostError) as err: + err.line = linenum_minus_one + 1 + raise err + + return entries diff --git a/src/read_file.py b/src/read_file.py index 4d59d0a..619a645 100644 --- a/src/read_file.py +++ b/src/read_file.py @@ -2,10 +2,27 @@ import base64 import entry -# TODO: Include file number in the error info -class FileFormatError(Exception): pass +class FileFormatError(Exception): + def __init__(self, string): + self.string = string + self.line = None -class VersionMismatch(Exception): pass + def __str__(self): + if self.line == None: + return self.string + else: + return 'Line %i: %s' % (self.line, self.string) + +class VersionMismatch(Exception): + def __init__(self, string): + self.string = string + self.line = None + + def __str__(self): + if self.line == None: + return self.string + else: + return 'Line %i: %s' % (self.line, self.string) def parse_header(header): """parse_header(bytes) → str @@ -105,10 +122,21 @@ def read(f): if len(lines) == 0: raise FileFormatError('Missing header') - file_comment = parse_header(lines[0]) + try: + file_comment = parse_header(lines[0]) + except (FileFormatError, VersionMismatch) as err: + err.line = 1 + raise err entries = [] - for line in lines[1:]: - entries.append(parse_entry(line)) + # Since line numbers are 1-indexed while lists in python are + # 0-indexed and we handle the first one separately, first one in the + # list is line 2 + for linenum_minus_2, line in enumerate(lines[1:]): + try: + entries.append(parse_entry(line)) + except FileFormatError as err: + err.line = linenum_minus_2 + 2 + raise err return entries, file_comment