import base64
import hashlib
import socket
import sys


def error(*args):
    """Print an error prefixed with the program name to stderr and exit(1)."""
    print('%s: Error:' % sys.argv[0], *args, file=sys.stderr)
    sys.exit(1)


def enc_len(length):
    """Encode *length* as a 2-byte little-endian unsigned integer.

    Raises:
        ValueError: if *length* is not in ``range(0, 65536)``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= length < 1 << 16:
        raise ValueError('length %r does not fit in 16 bits' % (length,))
    return length.to_bytes(2, 'little')


def dec_len(encoded):
    """Decode a 2-byte little-endian length as produced by ``enc_len``.

    Raises:
        ValueError: if *encoded* is not exactly 2 bytes long.
    """
    if len(encoded) != 2:
        raise ValueError('expected 2 length bytes, got %i' % len(encoded))
    return int.from_bytes(encoded, 'little')


def recv_exact(sock, count):
    """Receive exactly *count* bytes from *sock* and return them as bytes.

    Never reads past *count*, so data that follows stays in the socket.

    Raises:
        ConnectionError: if the peer closes the connection before *count*
            bytes have arrived.
    """
    buf = bytearray()
    while len(buf) < count:
        data = sock.recv(min(count - len(buf), 4096))
        if not data:
            raise ConnectionError('connection closed after %i of %i bytes'
                                  % (len(buf), count))
        buf.extend(data)
    return bytes(buf)


def send_pubkey(sock, pubkey):
    """Send *pubkey* on *sock*, prefixed with its ``enc_len`` length."""
    sock.sendall(enc_len(len(pubkey)) + pubkey)


def recv_pubkey(sock):
    """Receive one length-prefixed public key (as sent by ``send_pubkey``)."""
    length = dec_len(recv_exact(sock, 2))
    return recv_exact(sock, length)


def shutdown_quietly(sock):
    """Shut down both directions of *sock*, ignoring OSError.

    The peer may already have closed its end, which can make ``shutdown()``
    raise on some platforms; at that point the exchange is already complete.
    """
    try:
        sock.shutdown(socket.SHUT_RDWR)
    except OSError:
        pass


def server(server_pubkey, port):
    """Accept one connection on *port*, exchange keys, return the client's key.

    The client's key is received first, then *server_pubkey* is sent back.
    Exits the process via ``error()`` if binding or the exchange fails.
    """
    try:
        addrinfos = socket.getaddrinfo(None, port, socket.AF_UNSPEC,
                                       socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
    except OSError as e:
        error('Could not resolve port %i:' % port, e)

    sock = None
    for af, socktype, proto, _canonname, sa in addrinfos:
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sa)
            sock.listen(1)
        except OSError:
            sock.close()
            sock = None
            continue
        break
    if sock is None:
        error('Could not bind on port %i' % port)

    with sock:
        conn, addr = sock.accept()
        with conn:
            # addr is (host, port) for IPv4 but (host, port, flowinfo,
            # scope_id) for IPv6, so only index the host.
            print('Connection from %s' % addr[0])
            try:
                client_pubkey = recv_pubkey(conn)
            except OSError:
                error('Client public key could not be received')
            try:
                send_pubkey(conn, server_pubkey)
            except OSError:
                error('Server public key could not be sent')
            shutdown_quietly(conn)
    return client_pubkey


def client(client_pubkey, host, port):
    """Connect to *host*:*port*, exchange keys, return the server's key.

    *client_pubkey* is sent first, then the server's key is received.
    Exits the process via ``error()`` if connecting or the exchange fails.
    """
    try:
        addrinfos = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                       socket.SOCK_STREAM)
    except OSError as e:
        error('Could not resolve %s:' % host, e)

    sock = None
    for af, socktype, proto, _canonname, sa in addrinfos:
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.connect(sa)
        except OSError:
            sock.close()
            sock = None
            continue
        break
    if sock is None:
        error('Could not connect to %s on port %i' % (host, port))

    print('Connected to %s' % host)
    with sock:
        try:
            send_pubkey(sock, client_pubkey)
        except OSError:
            error('Client public key could not be sent')
        try:
            server_pubkey = recv_pubkey(sock)
        except OSError:
            error('Server public key could not be received')
        shutdown_quietly(sock)
    return server_pubkey


def sha512(x):
    """Return the raw SHA-512 digest of *x*."""
    return hashlib.sha512(x).digest()


def auth_hash(client_pubkey, server_pubkey):
    """Return a 36-byte authentication hash over both public keys.

    Each key is length-prefixed (``enc_len``) and hashed. The two digests
    are concatenated (client first) and hashed again, and the result is
    truncated to 32 bytes. A 4-byte check value (the first 4 bytes of
    SHA-512 of those 32 bytes) is appended.
    """
    client_pubkey_hash = sha512(enc_len(len(client_pubkey)) + client_pubkey)
    server_pubkey_hash = sha512(enc_len(len(server_pubkey)) + server_pubkey)
    combined_hash = sha512(client_pubkey_hash + server_pubkey_hash)
    truncated_hash = combined_hash[:32]
    hash_check = sha512(truncated_hash)[:4]
    return truncated_hash + hash_check


def chunk(sliceable, length):
    """Yield consecutive slices of *sliceable*, each at most *length* long."""
    for i in range(0, len(sliceable), length):
        yield sliceable[i:i + length]


def format_hash(hash_bytes):
    """Render *hash_bytes* as base64 split into dash-separated 4-char groups."""
    hash_base64 = base64.b64encode(hash_bytes).decode()
    return '-'.join(chunk(hash_base64, 4))


def verify(client_pubkey, server_pubkey):
    """Print the authentication hash for the key pair (no verification yet)."""
    own_hash = auth_hash(client_pubkey, server_pubkey)
    print('Authentication hash: %s' % format_hash(own_hash))
    # TODO: Actually verify


def usage():
    """Print usage to stderr and exit with status 2."""
    prog = sys.argv[0]
    print('Usage: %s server PORT' % prog, file=sys.stderr)
    print('Usage: %s client HOST PORT' % prog, file=sys.stderr)
    sys.exit(2)


def parse_port(text):
    """Parse a TCP port number in 1..65535, exiting via ``error()`` otherwise."""
    try:
        port = int(text)
    except ValueError:
        error('Invalid port: %s' % text)
    if not 0 < port < 1 << 16:
        error('Port out of range: %i' % port)
    return port


def main():
    # TODO: Actual argument parsing
    # TODO: Read pubkeys from files
    # TODO: Write pubkeys to files
    args = sys.argv[1:]
    if len(args) == 2 and args[0] == 'server':
        port = parse_port(args[1])
        server_pubkey = b'server\n'
        client_pubkey = server(server_pubkey, port)
        verify(client_pubkey, server_pubkey)
    elif len(args) == 3 and args[0] == 'client':
        host = args[1]
        port = parse_port(args[2])
        client_pubkey = b'client\n'
        server_pubkey = client(client_pubkey, host, port)
        verify(client_pubkey, server_pubkey)
    else:
        usage()


if __name__ == '__main__':
    main()