# kishib/kishib.py
#
# 408 lines
# 11 KiB
# Python

import base64
import binascii
import getopt
import hashlib
import os
import secrets
import socket
import sys
def error(*args):
    """Report a fatal error on stderr, prefixed with the program name, and exit with status 1."""
    program_name = os.path.basename(sys.argv[0])
    print(program_name + ': Error:', *args, file=sys.stderr)
    sys.exit(1)
# ------------------------------------------------------------------
# Client - server communication
# ------------------------------------------------------------------
def enc_len(length):
    """Encode length as two little-endian bytes, the length prefix used on the wire.

    Raises ValueError if length does not fit in an unsigned 16-bit integer.
    """
    # Explicit check rather than assert: asserts are removed under `python -O`
    if not 0 <= length < 1 << 16:
        raise ValueError('length %r does not fit in 16 bits' % (length,))
    return bytes([length & 0xff, length >> 8])
def dec_len(encoded):
    """Decode a two-byte little-endian length prefix (the inverse of enc_len)."""
    lo, hi = encoded
    return lo | hi << 8
def server(server_pubkey, port):
    """Listen on port, receive one client's public key and reply with ours.

    Both keys are framed as a two-byte length (enc_len) followed by the key
    bytes. Returns the client's public key as bytes. Exits via error() if no
    socket can be bound, or if the client's key is not received in full.
    """
    def recv_exactly(conn, count):
        # recv() may return fewer bytes than asked for (even for the 2-byte
        # length prefix), so keep reading until count bytes have arrived,
        # never asking for more than is still missing. None means the peer
        # closed the connection early.
        received = bytearray()
        while len(received) < count:
            data = conn.recv(min(count - len(received), 1024))
            if not data:
                return None
            received.extend(data)
        return bytes(received)

    sock = None
    for af, socktype, proto, _canonname, sa in socket.getaddrinfo(
            None, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
            socket.AI_PASSIVE):
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sa)
            sock.listen(1)
        except OSError:
            sock.close()
            sock = None
            continue
        break
    if sock is None:
        error('Could not bind on port %i' % port)
    with sock:
        conn, addr = sock.accept()
        with conn:
            # addr is (host, port) for IPv4 but (host, port, flowinfo,
            # scope_id) for IPv6, so only take the first element
            print('Connection from %s' % addr[0])
            encoded_len = recv_exactly(conn, 2)
            client_pubkey = None
            if encoded_len is not None:
                client_pubkey = recv_exactly(conn, dec_len(encoded_len))
            if client_pubkey is None:
                error('Client public key could not be received')
            conn.sendall(enc_len(len(server_pubkey)))
            conn.sendall(server_pubkey)
            try:
                conn.shutdown(socket.SHUT_RDWR)
            except OSError:
                # Best effort: our key has already been sent, so a failure
                # here (e.g. the peer already disconnected) is not fatal.
                pass
    return client_pubkey
def client(client_pubkey, host, port):
    """Connect to host on port, send our public key and return the server's.

    Both keys are framed as a two-byte length (enc_len) followed by the key
    bytes. Returns the server's public key as bytes. Exits via error() if
    host cannot be resolved or connected to, or if the server's key is not
    received in full.
    """
    def recv_exactly(conn, count):
        # recv() may return fewer bytes than asked for (even for the 2-byte
        # length prefix), so keep reading until count bytes have arrived,
        # never asking for more than is still missing. None means the peer
        # closed the connection early.
        received = bytearray()
        while len(received) < count:
            data = conn.recv(min(count - len(received), 1024))
            if not data:
                return None
            received.extend(data)
        return bytes(received)

    try:
        addresses = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                       socket.SOCK_STREAM)
    except socket.gaierror as err:
        # An unresolvable host name would otherwise surface as a traceback
        error('Could not resolve %s: %s' % (host, err))
    sock = None
    for af, socktype, proto, _canonname, sa in addresses:
        try:
            sock = socket.socket(af, socktype, proto)
        except OSError:
            sock = None
            continue
        try:
            sock.connect(sa)
        except OSError:
            sock.close()
            sock = None
            continue
        break
    if sock is None:
        error('Could not connect to %s on port %i' % (host, port))
    print('Connected to %s' % host)
    with sock:
        sock.sendall(enc_len(len(client_pubkey)))
        sock.sendall(client_pubkey)
        encoded_len = recv_exactly(sock, 2)
        server_pubkey = None
        if encoded_len is not None:
            server_pubkey = recv_exactly(sock, dec_len(encoded_len))
        if server_pubkey is None:
            error('Server public key could not be received')
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Best effort: the exchange is complete, so a failure here (e.g.
            # the server already closed its end) is not fatal.
            pass
    return server_pubkey
# ------------------------------------------------------------------
# Hash
# ------------------------------------------------------------------
def sha512(x):
    """Return the raw 64-byte SHA-512 digest of the bytes-like object x."""
    return hashlib.sha512(x).digest()
def auth_hash(client_pubkey, server_pubkey):
    """Derive the 20-byte authentication hash that both users compare.

    Each key is hashed together with its two-byte length prefix. The two
    digests are concatenated (client first) and hashed again, and the result
    is cut to 16 bytes. A 4-byte check value, the first bytes of the SHA-512
    of those 16 bytes, is appended so parse_hash can reject mistyped hashes.
    """
    def framed_digest(pubkey):
        # Hash the key with the same length prefix used when it is sent
        return sha512(enc_len(len(pubkey)) + pubkey)

    combined = sha512(framed_digest(client_pubkey) + framed_digest(server_pubkey))
    truncated = combined[:16]
    return truncated + sha512(truncated)[:4]
def chunk(sliceable, length):
    """Lazily yield consecutive slices of sliceable, each length items long.

    The final slice may be shorter than length.
    """
    yield from (sliceable[start:start + length]
                for start in range(0, len(sliceable), length))
def b32_encode(binary):
    """Encode binary with kishib's variant of RFC 4648 base32.

    Differences from the standard encoding:
    * all letters are lowercase
    * 'i' is written as '8' and 'o' is written as '9'
    """
    text = base64.b32encode(binary).decode().lower()
    for letter, digit in (('i', '8'), ('o', '9')):
        text = text.replace(letter, digit)
    return text
def format_hash(hash_bytes):
    """Render hash_bytes in kishib's base32 as dash-separated groups of four characters."""
    return '-'.join(chunk(b32_encode(hash_bytes), 4))
class HashFormatError(Exception):
    """Raised by parse_hash when the text is not a well-formed kishib hash."""
class HashChecksumError(Exception):
    """Raised by parse_hash when a well-formed hash fails its 4-byte check value."""
def parse_hash(auth_hash):
    """Parse an authentication hash as displayed by format_hash.

    The expected text is 8 dash-separated groups of 4 characters in kishib's
    modified base32 (lowercase, '8' in place of 'i', '9' in place of 'o').
    Leading and trailing whitespace is ignored.

    Returns the decoded bytes (on success: the 16-byte hash followed by its
    4-byte check value). Raises HashFormatError if the text is malformed and
    HashChecksumError if the check value does not match the hash.
    """
    # Hash consists of 8 segments of four characters
    segments = auth_hash.strip().split('-')
    if len(segments) != 8 or not all(len(segment) == 4 for segment in segments):
        raise HashFormatError
    combined = ''.join(segments)
    # The encoding never contains 'i' or 'o' (they are written as '8'/'9')
    if 'i' in combined or 'o' in combined:
        raise HashFormatError
    standard_base32 = combined.replace('8', 'i').replace('9', 'o').upper()
    try:
        binary = base64.b32decode(standard_base32)
    except ValueError as err:
        # binascii.Error (bad alphabet or padding) subclasses ValueError, and
        # non-ASCII input raises a plain ValueError; both mean a bad format.
        raise HashFormatError from err
    truncated_hash = binary[:16]
    hash_check = binary[16:]
    # Using secrets.compare_digest is not necessary, but I feel it is a
    # good habit to avoid comparing hashes with variable-timed comparisons
    if not secrets.compare_digest(sha512(truncated_hash)[:4], hash_check):
        raise HashChecksumError
    return binary
# ------------------------------------------------------------------
# File format parsing and serialization
# ------------------------------------------------------------------
class PubkeyParseError(Exception):
    """Raised when a public key is not of the form 'algorithm keymaterial [comment]'."""


def parse_pubkey(pubkey):
    """Split a single-line public key (bytes) into its fields.

    The expected form is ``algorithm keymaterial [comment]``, separated by
    spaces, optionally followed by trailing CR/LF characters. Everything
    after the second space is the comment, so a comment may contain spaces.

    Returns (algorithm, keymaterial, comment) as bytes; comment is None when
    absent. Raises PubkeyParseError if the key spans more than one line or
    the algorithm or key material is missing or empty.
    """
    # Strip trailing newlines
    pubkey = pubkey.rstrip(b'\r\n')
    # The fields are later written out as one newline-terminated line, so an
    # embedded newline would add extra lines to the output file
    if b'\n' in pubkey:
        raise PubkeyParseError
    # algorithm keymaterial [comment]; split at most twice so the comment
    # keeps its internal spaces
    fields = pubkey.split(b' ', 2)
    if len(fields) < 2 or not fields[0] or not fields[1]:
        raise PubkeyParseError
    algorithm, keymaterial, *rest = fields
    comment = rest[0] if rest else None
    return algorithm, keymaterial, comment
def serialize_known_hosts(hostname, port, algorithm, keymaterial):
    """Format one known_hosts line (bytes) for hostname.

    ssh looks up hosts on the default port 22 by their bare name and only
    uses the ``[host]:port`` form for other ports. Both port None and port 22
    therefore produce the bare form, so the entry actually matches.
    """
    if port is None or port == 22:
        return b'%s %s %s\n' % (hostname, algorithm, keymaterial)
    return b'[%s]:%i %s %s\n' % (hostname, port, algorithm, keymaterial)
def serialize_authorized_keys(algorithm, keymaterial, comment):
    """Format one newline-terminated authorized_keys line; comment is omitted when None."""
    fields = [algorithm, keymaterial]
    if comment is not None:
        fields.append(comment)
    return b' '.join(fields) + b'\n'
# ------------------------------------------------------------------
# UI
# ------------------------------------------------------------------
def usage(part=None):
    """Print the usage line for part ('client' or 'server', both when None) and exit with status 1."""
    program_name = os.path.basename(sys.argv[0])
    synopses = (
        ('client', 'Usage: %s client [-p <port>] [-i <pubkey>] [-o <known_hosts>] [-P <ssh port>] <host>'),
        ('server', 'Usage: %s server [-p <port>] [-i <pubkey>] [-o <authorized_keys>]'),
    )
    for name, template in synopses:
        if part is None or part == name:
            print(template % program_name, file=sys.stderr)
    sys.exit(1)
def verify(client_pubkey, server_pubkey):
    """Interactively confirm that both ends computed the same authentication hash.

    Prints this side's hash, then prompts until the user answers 'yes', 'no',
    or pastes the hash shown on the other end. Returns only when the hashes
    are confirmed to match. A 'no', a mismatching pasted hash, or end of
    input exits via error().
    """
    own_hash = auth_hash(client_pubkey, server_pubkey)
    print('Authentication hash: %s' % format_hash(own_hash))
    while True:
        try:
            user_input = input('Do the hashes match (yes/no/[paste])? ').strip()
        except EOFError:
            # stdin was closed, so no confirmation can arrive: treat as 'no'
            print()
            error('Could not transfer the keys')
        if user_input == 'no':
            error('Could not transfer the keys')
        if user_input == 'yes':
            return
        try:
            other_hash = parse_hash(user_input)
        except HashFormatError:
            print('Expected \'yes\', \'no\' or a base32-encoded hash')
            continue
        except HashChecksumError:
            print('Hash checksum check failed')
            continue
        if not secrets.compare_digest(own_hash, other_hash):
            error('Could not transfer the keys')
        print('Hash matches. You can now type \'yes\' on the other end')
        return
def _parse_port(text):
    """Return text as a TCP port number, exiting with an error if it is not one."""
    try:
        port = int(text)
    except ValueError:
        error('Port needs to be a number')
    # Reject out-of-range values up front; otherwise they would reach the
    # socket calls or be written into known_hosts
    if not 0 < port < 1 << 16:
        error('Port needs to be between 1 and 65535')
    return port


def _read_first_file(paths):
    """Return the contents of the first file in paths that can be read, or None."""
    for path in paths:
        try:
            with open(path, 'rb') as f:
                return f.read()
        except IOError:
            continue
    return None


def _read_pubkey_file(path, side):
    """Return the contents of path, exiting with an error naming side ('client'/'server') on failure."""
    try:
        with open(path, 'rb') as f:
            return f.read()
    except IOError as err:
        error('Could not read %s public key: %s' % (side, err))


def _check_pubkey_format(pubkey):
    """Exit with an error unless pubkey parses as a public key line."""
    try:
        parse_pubkey(pubkey)
    except PubkeyParseError:
        error('Public key is in an unrecognized format')


def _home_dir():
    """Return $HOME, exiting with an error if it is not set."""
    if 'HOME' not in os.environ:
        error('Cannot locate homedir, $HOME is not set')
    return os.environ['HOME']


def _ends_without_newline(path):
    """Return True if path is a non-empty regular file whose last byte is not a newline."""
    # Only inspect regular files: opening e.g. a FIFO for reading could block
    if not os.path.isfile(path):
        return False
    try:
        with open(path, 'rb') as f:
            f.seek(0, os.SEEK_END)
            if f.tell() == 0:
                return False
            f.seek(-1, os.SEEK_END)
            return f.read(1) != b'\n'
    except IOError:
        return False


def _append_entry(path, entry, kind):
    """Append entry to path (creating it if needed), exiting with an error naming kind on failure."""
    # Appending to a file whose last line lacks a newline would merge the new
    # entry into that line and corrupt both, so terminate that line first
    if _ends_without_newline(path):
        entry = b'\n' + entry
    try:
        with open(path, 'ab') as f:
            f.write(entry)
    except IOError as err:
        error('Could not write %s entry: %s' % (kind, err))


def _run_server(port, pubkey_file, output_file):
    """Offer our host key on port and append the verified client key to authorized_keys."""
    if pubkey_file is None:
        # Try the default ssh host key location
        server_pubkey = _read_first_file(
            '/etc/ssh/ssh_host_' + algorithm + '_key.pub'
            for algorithm in ('ed25519', 'ecdsa', 'rsa'))
        if server_pubkey is None:
            error('Could not find server public key (tried /etc/ssh/ssh_host_{ed25519,ecdsa,rsa}_key.pub)')
    else:
        server_pubkey = _read_pubkey_file(pubkey_file, 'server')
    _check_pubkey_format(server_pubkey)
    client_pubkey = server(server_pubkey, port)
    verify(client_pubkey, server_pubkey)
    try:
        algorithm, keymaterial, comment = parse_pubkey(client_pubkey)
    except PubkeyParseError:
        error('Parse error on client\'s pubkey')
    authorized_keys_entry = serialize_authorized_keys(algorithm, keymaterial, comment)
    if output_file is None:
        # Try ~/.ssh/authorized_keys
        output_file = _home_dir() + '/.ssh/authorized_keys'
    _append_entry(output_file, authorized_keys_entry, 'authorized_keys')


def _run_client(host, port, ssh_port, pubkey_file, output_file):
    """Send our key to host:port and append the verified server key to known_hosts."""
    if pubkey_file is None:
        # Try the default ssh client key location
        home = _home_dir()
        client_pubkey = _read_first_file(
            home + '/.ssh/id_' + algorithm + '.pub'
            for algorithm in ('ed25519', 'ecdsa', 'rsa'))
        if client_pubkey is None:
            error('Could not find client public key (tried ~/.ssh/id_{ed25519,ecdsa,rsa}.pub)')
    else:
        client_pubkey = _read_pubkey_file(pubkey_file, 'client')
    _check_pubkey_format(client_pubkey)
    # Support internationalized domain names
    try:
        host = host.encode('idna').decode()
    except UnicodeError:
        # e.g. an empty or over-long label such as 'a..b'
        error('Invalid hostname: %s' % host)
    server_pubkey = client(client_pubkey, host, port)
    verify(client_pubkey, server_pubkey)
    try:
        algorithm, keymaterial, comment = parse_pubkey(server_pubkey)
    except PubkeyParseError:
        error('Parse error on server\'s pubkey')
    known_hosts_entry = serialize_known_hosts(host.encode(), ssh_port, algorithm, keymaterial)
    if output_file is None:
        # Try ~/.ssh/known_hosts
        output_file = _home_dir() + '/.ssh/known_hosts'
    _append_entry(output_file, known_hosts_entry, 'known_hosts')


def main():
    """Command-line entry point: parse sys.argv and run the 'server' or 'client' side."""
    if len(sys.argv) < 2:
        usage()
    command = sys.argv[1]
    try:
        opts, fixed = getopt.gnu_getopt(sys.argv[2:], 'p:i:o:P:')
    except getopt.GetoptError as err:
        # Report the offending option, then show the relevant usage line(s)
        print('%s: Error: %s' % (os.path.basename(sys.argv[0]), err), file=sys.stderr)
        usage(command if command in ('client', 'server') else None)
    port = 16889
    pubkey_file = None
    output_file = None
    ssh_port = None
    for switch, arg in opts:
        if switch == '-p':
            port = _parse_port(arg)
        elif switch == '-i':
            pubkey_file = arg
        elif switch == '-o':
            output_file = arg
        elif switch == '-P':
            ssh_port = _parse_port(arg)
    if command == 'server':
        if fixed or ssh_port is not None:
            usage('server')
        _run_server(port, pubkey_file, output_file)
    elif command == 'client':
        if len(fixed) != 1:
            usage('client')
        _run_client(fixed[0], port, ssh_port, pubkey_file, output_file)
    else:
        usage()
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import
    main()