"""Exchange SSH public keys between two machines over TCP.

The server side sends its host public key and appends the client's key to
authorized_keys; the client side sends its user public key and appends the
server's key to known_hosts. Before anything is written, both sides print an
authentication hash derived from both keys, and the users confirm that the
hashes match (by answering 'yes' or pasting the other side's hash).
"""

import base64
import binascii
import getopt
import hashlib
import os
import secrets
import socket
import sys

# Key types probed, in order of preference, when no -i option is given
KEY_ALGORITHMS = ('ed25519', 'ecdsa', 'rsa')

# Default TCP port used for the key exchange itself
DEFAULT_PORT = 16889

# ssh's default port; known_hosts entries for it must use the bare hostname
# form, because ssh only looks up "[host]:port" for non-default ports
SSH_DEFAULT_PORT = 22


def error(*args):
    """Print an error message prefixed with the program name and exit(1)."""
    print('%s: Error:' % os.path.basename(sys.argv[0]), *args, file=sys.stderr)
    sys.exit(1)


# ------------------------------------------------------------------
# Client - server communication
# ------------------------------------------------------------------

# Frame lengths are encoded in 2 bytes, so payloads must be shorter than this
FRAME_LEN_LIMIT = 1 << 16


def enc_len(length):
    """Encode *length* as the 2-byte little-endian frame header.

    Raises ValueError if *length* does not fit in 16 bits.
    """
    if not 0 <= length < FRAME_LEN_LIMIT:
        raise ValueError('length %i does not fit in 16 bits' % length)
    return bytes([length & 0xff, length >> 8])


def dec_len(encoded):
    """Decode a 2-byte little-endian frame header produced by enc_len."""
    low_byte, high_byte = encoded
    return (high_byte << 8) | low_byte


def recv_exact(sock, length):
    """Receive exactly *length* bytes from *sock*.

    A single recv() may return fewer bytes than requested, so keep reading
    until the requested amount has arrived.

    Raises ConnectionError (an OSError) if the peer closes the connection
    before *length* bytes were received.
    """
    received = bytearray()
    while len(received) < length:
        data = sock.recv(length - len(received))
        if not data:
            raise ConnectionError('connection closed after %i of %i bytes'
                                  % (len(received), length))
        received.extend(data)
    return bytes(received)


def send_frame(sock, payload):
    """Send *payload* preceded by its 2-byte little-endian length."""
    sock.sendall(enc_len(len(payload)) + payload)


def recv_frame(sock):
    """Receive one length-prefixed payload as written by send_frame.

    Raises OSError (ConnectionError on a premature close) on failure.
    """
    length = dec_len(recv_exact(sock, 2))
    return recv_exact(sock, length)


def shutdown_quietly(sock):
    """Shut down both directions of *sock*, ignoring errors.

    By the time this is called the exchange is complete; if the peer has
    already closed its end, shutdown() may raise OSError (e.g. ENOTCONN) on
    some platforms, which is harmless here.
    """
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except OSError:
        pass


def server(server_pubkey, port):
    """Accept one client on *port*, swap public keys and return the client's.

    The client's length-prefixed key is received first, then *server_pubkey*
    is sent back the same way. Exits via error() on bind or transfer failure.
    """
    sock = None
    for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
            None, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
            socket.AI_PASSIVE):
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sa)
            sock.listen(1)
        except OSError:
            sock.close()
            sock = None
            continue
        break

    if sock is None:
        error('Could not bind on port %i' % port)

    # Only a single client is served, so the listener can go right away
    with sock:
        conn, addr = sock.accept()

    with conn:
        # addr is (host, port) for IPv4 but (host, port, flowinfo, scope_id)
        # for IPv6, so index the host instead of unpacking two values
        print('Connection from %s' % addr[0])
        try:
            client_pubkey = recv_frame(conn)
        except OSError:
            error('Client public key could not be received')
        try:
            send_frame(conn, server_pubkey)
        except OSError:
            error('Server public key could not be sent')
        shutdown_quietly(conn)

    return client_pubkey


def client(client_pubkey, host, port):
    """Connect to *host*:*port*, swap public keys and return the server's.

    *client_pubkey* is sent first as a length-prefixed frame, then the
    server's key is read back the same way. Exits via error() on resolution,
    connection or transfer failure.
    """
    try:
        addresses = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                       socket.SOCK_STREAM)
    except OSError as err:
        error('Could not resolve %s: %s' % (host, err))

    sock = None
    for af, socktype, proto, _canonname, sa in addresses:
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.connect(sa)
        except OSError:
            sock.close()
            sock = None
            continue
        break

    if sock is None:
        error('Could not connect to %s on port %i' % (host, port))

    print('Connected to %s' % host)

    with sock:
        try:
            send_frame(sock, client_pubkey)
        except OSError:
            error('Client public key could not be sent')
        try:
            server_pubkey = recv_frame(sock)
        except OSError:
            error('Server public key could not be received')
        shutdown_quietly(sock)

    return server_pubkey


# ------------------------------------------------------------------
# Hash
# ------------------------------------------------------------------

def sha512(x):
    """Return the raw SHA-512 digest of the bytes *x*."""
    return hashlib.sha512(x).digest()


def auth_hash(client_pubkey, server_pubkey):
    """Return the 20-byte authentication hash for a pair of public keys.

    Each key is hashed together with its 2-byte length prefix, the two
    digests are hashed again, and the first 16 bytes of that are kept. A
    4-byte checksum (the first 4 bytes of SHA-512 of those 16 bytes) is
    appended so that typos in a pasted hash can be detected.
    """
    client_pubkey_hash = sha512(enc_len(len(client_pubkey)) + client_pubkey)
    server_pubkey_hash = sha512(enc_len(len(server_pubkey)) + server_pubkey)
    combined_hash = sha512(client_pubkey_hash + server_pubkey_hash)
    truncated_hash = combined_hash[:16]
    hash_check = sha512(truncated_hash)[:4]
    return truncated_hash + hash_check


def chunk(sliceable, length):
    """Yield consecutive slices of *sliceable* of at most *length* items."""
    for i in range(0, len(sliceable), length):
        yield sliceable[i:i + length]


def b32_encode(binary):
    """Encode *binary* in Kishib's modified base32.

    Kishib uses a modified base32 compared to RFC 4648:
    * All letters are lowercase
    * 'i' is replaced with '8' and 'o' is replaced with '9'
    """
    encoded = base64.b32encode(binary).decode()
    encoded = encoded.lower()
    encoded = encoded.replace('i', '8').replace('o', '9')
    return encoded


def format_hash(hash_bytes):
    """Render *hash_bytes* as dash-separated groups of 4 base32 characters."""
    hash_base32 = b32_encode(hash_bytes)
    return '-'.join(chunk(hash_base32, 4))


class HashFormatError(Exception):
    """The text is not a well-formed authentication hash."""


class HashChecksumError(Exception):
    """The hash is well-formed but its 4-byte checksum does not match."""


def parse_hash(auth_hash):
    """Parse a hash in format_hash's layout back into its 20 raw bytes.

    Input is accepted case-insensitively.

    Raises HashFormatError on malformed input and HashChecksumError when the
    embedded checksum does not match.
    """
    # Hash consists of 8 segments of four characters
    segments = auth_hash.split('-')
    if len(segments) != 8 or not all(len(s) == 4 for s in segments):
        raise HashFormatError

    # Lowercase first so that 'I' and 'O' are rejected just like 'i' and 'o'
    combined = ''.join(segments).lower()
    if 'i' in combined or 'o' in combined:
        raise HashFormatError

    standard_base32 = combined.replace('8', 'i').replace('9', 'o').upper()
    try:
        binary = base64.b32decode(standard_base32)
    except ValueError as err:
        # binascii.Error (a ValueError subclass) for bad base32, plain
        # ValueError for non-ASCII characters
        raise HashFormatError from err

    truncated_hash = binary[:16]
    hash_check = binary[16:]
    # Using secrets.compare_digest is not necessary, but I feel it is a
    # good habit to avoid comparing hashes with variable-timed comparisons
    if not secrets.compare_digest(sha512(truncated_hash)[:4], hash_check):
        raise HashChecksumError
    return binary


# ------------------------------------------------------------------
# File format parsing and serialization
# ------------------------------------------------------------------

class PubkeyParseError(Exception):
    """The data is not a single-line OpenSSH-style public key."""


def parse_pubkey(pubkey):
    """Split a one-line public key into (algorithm, keymaterial, comment).

    All three are bytes; comment is None when absent. Fields may be separated
    by any run of whitespace, and the comment keeps its internal spacing.

    Raises PubkeyParseError if the key spans several lines or has fewer than
    two fields.
    """
    # Strip trailing newlines
    while pubkey[-1:] in (b'\r', b'\n'):
        pubkey = pubkey[:-1]

    # There should be no newlines after this; an embedded newline would
    # otherwise inject extra lines into authorized_keys/known_hosts
    if b'\n' in pubkey:
        raise PubkeyParseError

    # algorithm keymaterial [comment]
    fields = pubkey.split(None, 2)
    if len(fields) < 2:
        raise PubkeyParseError

    algorithm, keymaterial, *rest = fields
    comment = rest[0] if rest else None
    return algorithm, keymaterial, comment


def serialize_known_hosts(hostname, port, algorithm, keymaterial):
    """Format a known_hosts line (bytes) for *hostname*.

    A *port* of None or 22 produces the bare-hostname form that ssh looks up
    for its default port; any other port uses the "[hostname]:port" form.
    """
    if port is None or port == SSH_DEFAULT_PORT:
        return b'%s %s %s\n' % (hostname, algorithm, keymaterial)
    return b'[%s]:%i %s %s\n' % (hostname, port, algorithm, keymaterial)


def serialize_authorized_keys(algorithm, keymaterial, comment):
    """Format an authorized_keys line (bytes); *comment* may be None."""
    if comment is None:
        return b'%s %s\n' % (algorithm, keymaterial)
    return b'%s %s %s\n' % (algorithm, keymaterial, comment)


def ends_without_newline(path):
    """Return True if *path* is a non-empty regular file not ending in '\\n'.

    Anything that cannot be inspected (missing file, device, pipe, read
    error) counts as False.
    """
    try:
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            return False
        with open(path, 'rb') as f:
            f.seek(-1, os.SEEK_END)
            return f.read(1) != b'\n'
    except OSError:
        return False


def append_line(path, line):
    """Append the newline-terminated bytes *line* to the file at *path*.

    If the existing file lacks a trailing newline (e.g. after hand editing),
    one is written first so the new entry is not glued onto the last line.

    Raises OSError if the file cannot be opened or written.
    """
    if ends_without_newline(path):
        line = b'\n' + line
    with open(path, 'ab') as f:
        f.write(line)


# ------------------------------------------------------------------
# UI
# ------------------------------------------------------------------

def usage(part=None):
    """Print usage for 'client', 'server' or (None) both, then exit(1)."""
    prog = os.path.basename(sys.argv[0])
    if part in ('client', None):
        print('Usage: %s client [-p <port>] [-i <pubkey>] [-o <known_hosts>] '
              '[-P <ssh port>] <host>' % prog, file=sys.stderr)
    if part in ('server', None):
        print('Usage: %s server [-p <port>] [-i <pubkey>] '
              '[-o <authorized_keys>]' % prog, file=sys.stderr)
    sys.exit(1)


def verify(client_pubkey, server_pubkey):
    """Show the authentication hash and wait for the user to confirm it.

    The user answers 'yes', 'no', or pastes the other side's hash. Returns
    on confirmation; exits via error() on 'no', a mismatching hash, or when
    stdin is closed.
    """
    own_hash = auth_hash(client_pubkey, server_pubkey)
    print('Authentication hash: %s' % format_hash(own_hash))

    while True:
        try:
            # Pasted text often carries stray surrounding whitespace
            user_input = input('Do the hashes match (yes/no/[paste])? ').strip()
        except EOFError:
            # No way to get a confirmation any more
            error('Could not transfer the keys')

        if user_input == 'no':
            error('Could not transfer the keys')
        if user_input == 'yes':
            return

        try:
            other_hash = parse_hash(user_input)
        except HashFormatError:
            print('Expected \'yes\', \'no\' or a base32-encoded hash')
            continue
        except HashChecksumError:
            print('Hash checksum check failed')
            continue

        if not secrets.compare_digest(own_hash, other_hash):
            error('Could not transfer the keys')
        print('Hash matches. You can now type \'yes\' on the other end')
        return


def parse_port(arg):
    """Parse a TCP port number from *arg*, exiting via error() if invalid."""
    try:
        port = int(arg)
    except ValueError:
        error('Port needs to be a number')
    if not 0 < port < 1 << 16:
        error('Port needs to be between 1 and 65535')
    return port


def home_dir():
    """Return $HOME, exiting via error() if it is not set."""
    home = os.environ.get('HOME')
    if home is None:
        error('Cannot locate homedir, $HOME is not set')
    return home


def read_first_readable(paths):
    """Return the contents of the first readable file in *paths*, or None."""
    for path in paths:
        try:
            with open(path, 'rb') as f:
                return f.read()
        except OSError:
            continue
    return None


def read_pubkey_file(path, role):
    """Read the *role* ('server'/'client') public key from *path*.

    Exits via error() if the file cannot be read.
    """
    try:
        with open(path, 'rb') as f:
            return f.read()
    except OSError as err:
        error('Could not read %s public key: %s' % (role, err))


def check_local_pubkey(pubkey):
    """Exit via error() unless *pubkey* parses and fits in a single frame."""
    try:
        parse_pubkey(pubkey)
    except PubkeyParseError:
        error('Public key is in an unrecognized format')
    if len(pubkey) >= FRAME_LEN_LIMIT:
        error('Public key is too large')


def main():
    """Command-line entry point: `server` or `client` mode (see usage())."""
    if len(sys.argv) < 2:
        usage()

    command = sys.argv[1]
    try:
        opts, fixed = getopt.gnu_getopt(sys.argv[2:], 'p:i:o:P:')
    except getopt.GetoptError as err:
        print('%s: %s' % (os.path.basename(sys.argv[0]), err), file=sys.stderr)
        usage(command if command in ('client', 'server') else None)

    port = DEFAULT_PORT
    pubkey_file = None
    output_file = None
    ssh_port = None
    for switch, arg in opts:
        if switch == '-p':
            port = parse_port(arg)
        elif switch == '-i':
            pubkey_file = arg
        elif switch == '-o':
            output_file = arg
        elif switch == '-P':
            ssh_port = parse_port(arg)

    if command == 'server':
        if fixed or ssh_port is not None:
            usage('server')

        if pubkey_file is None:
            # Try the default ssh host key locations
            server_pubkey = read_first_readable(
                '/etc/ssh/ssh_host_' + algorithm + '_key.pub'
                for algorithm in KEY_ALGORITHMS)
            if server_pubkey is None:
                error('Could not find server public key (tried /etc/ssh/ssh_host_{ed25519,ecdsa,rsa}_key.pub)')
        else:
            server_pubkey = read_pubkey_file(pubkey_file, 'server')
        check_local_pubkey(server_pubkey)

        client_pubkey = server(server_pubkey, port)
        verify(client_pubkey, server_pubkey)

        try:
            algorithm, keymaterial, comment = parse_pubkey(client_pubkey)
        except PubkeyParseError:
            error('Parse error on client\'s pubkey')
        authorized_keys_entry = serialize_authorized_keys(algorithm, keymaterial, comment)

        if output_file is None:
            # Try ~/.ssh/authorized_keys
            output_file = home_dir() + '/.ssh/authorized_keys'
        try:
            append_line(output_file, authorized_keys_entry)
        except OSError as err:
            error('Could not write authorized_keys entry: %s' % err)

    elif command == 'client':
        if len(fixed) != 1:
            usage('client')

        if pubkey_file is None:
            # Try the default ssh client key locations
            home = home_dir()
            client_pubkey = read_first_readable(
                home + '/.ssh/id_' + algorithm + '.pub'
                for algorithm in KEY_ALGORITHMS)
            if client_pubkey is None:
                error('Could not find client public key (tried ~/.ssh/id_{ed25519,ecdsa,rsa}.pub)')
        else:
            client_pubkey = read_pubkey_file(pubkey_file, 'client')
        check_local_pubkey(client_pubkey)

        host, = fixed
        # Support internationalized domain names
        try:
            host = host.encode('idna').decode()
        except UnicodeError:
            error('Invalid hostname: %s' % host)

        server_pubkey = client(client_pubkey, host, port)
        verify(client_pubkey, server_pubkey)

        try:
            algorithm, keymaterial, comment = parse_pubkey(server_pubkey)
        except PubkeyParseError:
            error('Parse error on server\'s pubkey')
        known_hosts_entry = serialize_known_hosts(host.encode(), ssh_port, algorithm, keymaterial)

        if output_file is None:
            # Try ~/.ssh/known_hosts
            output_file = home_dir() + '/.ssh/known_hosts'
        try:
            append_line(output_file, known_hosts_entry)
        except OSError as err:
            error('Could not write known_hosts entry: %s' % err)

    else:
        usage()


if __name__ == '__main__':
    main()