263 lines
6.5 KiB
Python
263 lines
6.5 KiB
Python
import base64
|
|
import binascii
|
|
import getopt
|
|
import hashlib
|
|
import os.path
|
|
import secrets
|
|
import socket
|
|
import sys
|
|
|
|
def error(*args):
    """Print a fatal error to stderr, prefixed with the program name, then exit with status 1."""
    program = os.path.basename(sys.argv[0])
    prefix = '%s: Error:' % program
    print(prefix, *args, file=sys.stderr)
    sys.exit(1)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Client - server communication
|
|
# ------------------------------------------------------------------
|
|
|
|
def enc_len(length):
    """Encode *length* as the two-byte, low-byte-first prefix used on the wire.

    :param length: integer in the range 0..65535
    :returns: ``bytes`` of length 2, low byte first
    :raises ValueError: if *length* does not fit in 16 bits
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and this is a protocol invariant, not a debug aid.
    if not 0 <= length < 1 << 16:
        raise ValueError('length %r does not fit in 16 bits' % (length,))
    return bytes([length & 0xff, length >> 8])
|
|
|
|
def dec_len(encoded):
    """Decode the two-byte, low-byte-first length prefix produced by enc_len.

    *encoded* must contain exactly two byte values; anything else raises
    ValueError from the unpacking.
    """
    lo, hi = encoded
    return lo | hi << 8
|
|
|
|
def server(server_pubkey, port):
    """Accept one client on *port* and exchange public keys with it.

    The client sends its key first, framed as an enc_len length prefix
    followed by the key bytes; the server then answers with its own key in
    the same framing. Exits via error() if no socket can be bound or if the
    client's key does not arrive in full.

    :param server_pubkey: this side's public key (bytes; its length must fit
        the 16-bit enc_len prefix)
    :param port: TCP port to listen on, on the wildcard address
    :returns: the client's public key as ``bytes``
    """
    def recv_exact(conn, count):
        # recv() may return fewer bytes than requested, so loop; request only
        # what is still missing so nothing beyond *count* is consumed.
        # Returns None if the peer closes the connection early.
        received = bytearray()
        while len(received) < count:
            data = conn.recv(count - len(received))
            if not data:
                return None
            received.extend(data)
        return bytes(received)

    sock = None
    for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
            None, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
            socket.AI_PASSIVE):
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sa)
            sock.listen(1)
        except OSError:
            sock.close()
            sock = None
            continue
        break

    if sock is None:
        error('Could not bind on port %i' % port)

    # ``with`` closes the listening socket on every path, including accept()
    # failing and error() raising SystemExit.
    with sock:
        conn, addr = sock.accept()
        with conn:
            # addr is (host, port) for IPv4 but a 4-tuple for IPv6, so only
            # take the host instead of unpacking exactly two values.
            remote_host = addr[0]
            print('Connection from %s' % remote_host)

            length_prefix = recv_exact(conn, 2)
            if length_prefix is None:
                error('Client public key could not be received')
            client_pubkey = recv_exact(conn, dec_len(length_prefix))
            if client_pubkey is None:
                error('Client public key could not be received')

            conn.sendall(enc_len(len(server_pubkey)))
            conn.sendall(server_pubkey)

            conn.shutdown(socket.SHUT_RDWR)

    return client_pubkey
|
|
|
|
def client(client_pubkey, host, port):
    """Connect to the server at *host*:*port* and exchange public keys.

    Sends this side's key first (enc_len length prefix, then the key bytes)
    and then reads the server's key in the same framing. Exits via error()
    if the host cannot be resolved or reached, or if the server's key does
    not arrive in full.

    :param client_pubkey: this side's public key (bytes; its length must fit
        the 16-bit enc_len prefix)
    :param host: host name or address of the server
    :param port: TCP port of the server
    :returns: the server's public key as ``bytes``
    """
    def recv_exact(conn, count):
        # recv() may return fewer bytes than requested, so loop; request only
        # what is still missing so nothing beyond *count* is consumed.
        # Returns None if the peer closes the connection early.
        received = bytearray()
        while len(received) < count:
            data = conn.recv(count - len(received))
            if not data:
                return None
            received.extend(data)
        return bytes(received)

    try:
        candidates = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                        socket.SOCK_STREAM)
    except socket.gaierror:
        # An unknown host name would otherwise end in a traceback.
        error('Could not resolve %s' % host)

    sock = None
    for af, socktype, proto, _canonname, sa in candidates:
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.connect(sa)
        except OSError:
            sock.close()
            sock = None
            continue
        break

    if sock is None:
        error('Could not connect to %s on port %i' % (host, port))

    print('Connected to %s' % host)

    # ``with`` closes the socket on every path, including error() raising
    # SystemExit.
    with sock:
        sock.sendall(enc_len(len(client_pubkey)))
        sock.sendall(client_pubkey)

        length_prefix = recv_exact(sock, 2)
        if length_prefix is None:
            error('Server public key could not be received')
        server_pubkey = recv_exact(sock, dec_len(length_prefix))
        if server_pubkey is None:
            error('Server public key could not be received')

        sock.shutdown(socket.SHUT_RDWR)

    return server_pubkey
|
|
|
|
# ------------------------------------------------------------------
|
|
# Hash
|
|
# ------------------------------------------------------------------
|
|
|
|
def sha512(x):
    """Return the raw 64-byte SHA-512 digest of the bytes-like object *x*."""
    return hashlib.sha512(x).digest()
|
|
|
|
def auth_hash(client_pubkey, server_pubkey):
    """Derive the 20-byte authentication hash that both peers display.

    Each key is hashed together with its enc_len length prefix. The two
    digests (client first, then server) are hashed again and truncated to
    16 bytes. A 4-byte checksum, the first 4 bytes of SHA-512 over those
    16 bytes, is appended; parse_hash verifies it.
    """
    def key_digest(pubkey):
        return sha512(enc_len(len(pubkey)) + pubkey)

    combined = sha512(key_digest(client_pubkey) + key_digest(server_pubkey))
    body = combined[:16]
    return body + sha512(body)[:4]
|
|
|
|
def chunk(sliceable, length):
    """Lazily yield consecutive slices of *sliceable*, each *length* items long.

    The last slice is shorter when len(sliceable) is not a multiple of
    *length*.
    """
    starts = range(0, len(sliceable), length)
    yield from (sliceable[start:start + length] for start in starts)
|
|
|
|
def b32_encode(binary):
    """Encode *binary* with Kishib's base32 variant and return it as a str.

    The variant differs from RFC 4648 base32 in two ways:
    * the output is lowercase;
    * 'i' is written as '8' and 'o' as '9'.
    Padding '=' characters are kept as-is.
    """
    standard = base64.b32encode(binary).decode()
    return standard.lower().translate(str.maketrans('io', '89'))
|
|
|
|
def format_hash(hash_bytes):
    """Render *hash_bytes* as dash-separated groups of four base32 characters.

    A 20-byte hash from auth_hash becomes 32 characters, i.e. eight groups.
    """
    return '-'.join(chunk(b32_encode(hash_bytes), 4))
|
|
|
|
class HashFormatError(Exception):
    """Raised by parse_hash when the text is not a well-formed hash."""
|
|
class HashChecksumError(Exception):
    """Raised by parse_hash when the embedded 4-byte checksum does not match."""
|
|
|
|
def parse_hash(auth_hash):
    """Parse a hash typed or pasted by the user back into its 20 raw bytes.

    The expected form is what format_hash prints: eight dash-separated
    groups of four characters in Kishib's base32 alphabet (lowercase, with
    '8' and '9' standing in for 'i' and 'o'). Surrounding whitespace is
    ignored.

    :returns: the 20 decoded bytes (16-byte hash followed by 4-byte checksum)
    :raises HashFormatError: if the text is not a well-formed hash
    :raises HashChecksumError: if the embedded checksum does not match
    """
    # Tolerate stray spaces/newlines picked up when copying from a terminal.
    segments = auth_hash.strip().split('-')
    if len(segments) != 8 or not all(len(s) == 4 for s in segments):
        raise HashFormatError

    combined = ''.join(segments)
    # b32_encode never emits 'i' or 'o', so their presence means a typo.
    if 'i' in combined or 'o' in combined:
        raise HashFormatError

    standard_base32 = combined.replace('8', 'i').replace('9', 'o').upper()

    try:
        binary = base64.b32decode(standard_base32)
    except ValueError as err:
        # binascii.Error (invalid base32) subclasses ValueError; a plain
        # ValueError is raised for non-ASCII input, e.g. a pasted 'é'.
        raise HashFormatError from err

    # Trailing '=' padding decodes to fewer than 20 bytes: a format problem,
    # not a checksum mismatch.
    if len(binary) != 20:
        raise HashFormatError

    truncated_hash = binary[:16]
    hash_check = binary[16:]

    # Using secrets.compare_digest is not necessary, but I feel it is a
    # good habit to avoid comparing hashes with variable-timed comparisons
    if not secrets.compare_digest(sha512(truncated_hash)[:4], hash_check):
        raise HashChecksumError

    return binary
|
|
|
|
# ------------------------------------------------------------------
|
|
# UI
|
|
# ------------------------------------------------------------------
|
|
|
|
def usage(part = None):
    """Print usage to stderr and exit with status 1.

    :param part: 'client' or 'server' to show only that command's synopsis;
        None shows both. Any other value prints nothing before exiting.
    """
    program = os.path.basename(sys.argv[0])
    synopses = (
        ('client', 'Usage: %s client [-p <port>] <host>'),
        ('server', 'Usage: %s server [-p <port>]'),
    )
    for name, template in synopses:
        if part is None or part == name:
            print(template % program, file=sys.stderr)
    sys.exit(1)
|
|
|
|
def verify(client_pubkey, server_pubkey):
    """Show the authentication hash and ask the user to confirm it.

    The user may answer 'yes', answer 'no', or paste the hash shown on the
    other end. Returns normally on 'yes' or a matching pasted hash. Exits
    via error() on 'no', on a non-matching pasted hash, or when input ends.
    Malformed pastes or failed checksums print a message and re-prompt.
    """
    own_hash = auth_hash(client_pubkey, server_pubkey)
    print('Authentication hash: %s' % format_hash(own_hash))

    while True:
        try:
            # Surrounding whitespace is not meaningful in any accepted answer.
            user_input = input('Do the hashes match (yes/no/[paste])? ').strip()
        except EOFError:
            # stdin closed (e.g. Ctrl-D) without an answer: treat as 'no'.
            # The bare print() ends the prompt line before the error message.
            print()
            error('Could not transfer the keys')

        if user_input == 'no':
            error('Could not transfer the keys')
        if user_input == 'yes':
            return

        try:
            other_hash = parse_hash(user_input)
        except HashFormatError:
            print('Expected \'yes\', \'no\' or a base32-encoded hash')
            continue
        except HashChecksumError:
            print('Hash checksum check failed')
            continue

        if not secrets.compare_digest(own_hash, other_hash):
            error('Could not transfer the keys')
        print('Hash matches. You can now type \'yes\' on the other end')
        return
|
|
|
|
def main():
    """Command-line entry point.

    Usage: ``server [-p <port>]`` or ``client [-p <port>] <host>``. The
    two ends exchange public keys over TCP and then verify the
    authentication hash interactively.
    """
    # TODO: Read pubkeys from files
    # TODO: Write pubkeys to files
    if len(sys.argv) < 2:
        usage()

    command = sys.argv[1]
    # usage() only has per-command text for these two; anything else gets
    # the combined usage.
    usage_part = command if command in ('client', 'server') else None

    try:
        opts, fixed = getopt.gnu_getopt(sys.argv[2:], 'p:')
    except getopt.GetoptError as err:
        # Unknown option or '-p' without a value.
        print('%s: %s' % (os.path.basename(sys.argv[0]), err), file=sys.stderr)
        usage(usage_part)

    # TODO: Select an actual port
    port = 1234
    for switch, arg in opts:
        if switch == '-p':
            try:
                port = int(arg)
            except ValueError:
                error('Port needs to be a number')
            # Reject values that cannot be a TCP port before they reach the
            # socket layer. 0 is excluded too: the server would bind an
            # OS-chosen port that it never reports to the user.
            if not 0 < port < 1 << 16:
                error('Port needs to be between 1 and 65535')

    if command == 'server':
        if fixed:
            usage('server')

        # Placeholder key until keys are read from files (see TODO above).
        server_pubkey = b'server'
        client_pubkey = server(server_pubkey, port)
        verify(client_pubkey, server_pubkey)

    elif command == 'client':
        if len(fixed) != 1:
            usage('client')

        host, = fixed
        # Placeholder key until keys are read from files (see TODO above).
        client_pubkey = b'client'
        server_pubkey = client(client_pubkey, host, port)
        verify(client_pubkey, server_pubkey)

    else:
        usage()
|
|
|
|
# Run the command-line interface only when executed as a script, not when
# this file is imported as a module.
if __name__ == '__main__':
    main()
|