Use channels for internal communication

This commit is contained in:
Juhani Haverinen 2017-09-05 10:55:33 +03:00
parent 4e1efa5b61
commit 0657f423f3
2 changed files with 88 additions and 38 deletions

51
channel.py Normal file
View File

@ -0,0 +1,51 @@
import select
import socket
import threading
class Channel:
"""An asynchronic communication channel that can be used to send python object and can be poll()ed."""
def __init__(self):
# We use a socket to enable polling and blocking reads
self.write_socket, self.read_socket = socket.socketpair()
self.poll = select.poll()
self.poll.register(self.read_socket, select.POLLIN)
# Store messages in a list
self.mesages = []
self.messages_lock = threading.Lock()
def send(self, message):
# Add message to the list of messages and write to the write socket to signal there's data to read
with self.messages_lock:
self.write_socket.sendall(b'!')
self.mesages.append(message)
def recv(self, blocking = True):
# Timeout of -1 will make poll wait until data is available
# Timeout of 0 will make poll exit immediately if there's no data
if blocking:
timeout = -1
else:
timeout = 0
# See if there is data to read / wait until there is
results = self.poll.poll(timeout)
# None of the sockets were ready. This can only happen if we weren't blocking
# Return None to signal lack of data
if len(results) == 0:
assert not blocking
return None
# Remove first message from the list (FIFO principle), and read one byte from the socket
# This keeps the number of available messages and the number of bytes readable in the socket in sync
with self.messages_lock:
message = self.mesages.pop(0)
self.read_socket.recv(1)
return message
def fileno(self):
# Allows for a Channel object to be passed directly to poll()
return self.read_socket.fileno()

View File

@ -4,14 +4,17 @@ import socket
import threading import threading
from collections import namedtuple from collections import namedtuple
import channel
Server = namedtuple('Server', ['host', 'port']) Server = namedtuple('Server', ['host', 'port'])
# ServerThread(server, control_socket) # ServerThread(server, control_socket)
# Creates a new server main loop thread # Creates a new server main loop thread
class ServerThread(threading.Thread): class ServerThread(threading.Thread):
def __init__(self, server, control_socket): def __init__(self, server, control_channel, logging_channel):
self.server = server self.server = server
self.control_socket = control_socket self.control_channel = control_channel
self.logging_channel = logging_channel
self.server_socket_write_lock = threading.Lock() self.server_socket_write_lock = threading.Lock()
@ -24,23 +27,22 @@ class ServerThread(threading.Thread):
with self.server_socket_write_lock: with self.server_socket_write_lock:
self.server_socket.sendall(line + b'\r\n') self.server_socket.sendall(line + b'\r\n')
# FIXME: print is not thread safe # FIXME: use a real data structure
print('>' + line.decode(encoding = 'utf-8', errors = 'replace')) self.logging_channel.send('>' + line.decode(encoding = 'utf-8', errors = 'replace'))
def handle_line(self, line): def handle_line(self, line):
# TODO: implement line handling # TODO: implement line handling
# FIXME: print is not thread safe # FIXME: use a real data structure
print('<' + line.decode(encoding = 'utf-8', errors = 'replace')) self.logging_channel.send('<' + line.decode(encoding = 'utf-8', errors = 'replace'))
def mainloop(self): def mainloop(self):
# Register both the server and the control socket to our polling object # Register both the server socket and the control channel to or polling object
poll = select.poll() poll = select.poll()
poll.register(self.server_socket, select.POLLIN) poll.register(self.server_socket, select.POLLIN)
poll.register(self.control_socket, select.POLLIN) poll.register(self.control_channel, select.POLLIN)
# Keep buffers for input and output # Keep buffer for input
server_input_buffer = bytearray() server_input_buffer = bytearray()
control_input_buffer = bytearray()
quitting = False quitting = False
while not quitting: while not quitting:
@ -62,15 +64,21 @@ class ServerThread(threading.Thread):
self.handle_line(line) self.handle_line(line)
# Control # Control
elif fd == self.control_socket.fileno(): elif fd == self.control_channel.fileno():
# Read into buffer and handle full commands command = self.control_channel.recv()
data = self.control_socket.recv(1024)
control_input_buffer.extend(data)
# TODO: implement command handling # FIXME: use a real data structure
if len(control_input_buffer) > 1: if command == 'q':
quitting = True quitting = True
elif len(command) > 0 and command[0] == '/':
irc_command, space, arguments = command[1:].encode('utf-8').partition(b' ')
line = irc_command.upper() + space + arguments
self.send_line_raw(line)
else:
self.logging_channel.send('?')
else: else:
assert False #unreachable assert False #unreachable
@ -81,8 +89,8 @@ class ServerThread(threading.Thread):
self.server_socket = socket.create_connection(address) self.server_socket = socket.create_connection(address)
except ConnectionRefusedError: except ConnectionRefusedError:
# Tell controller we failed # Tell controller we failed
self.control_socket.sendall(b'F') self.logging_channel.send('f')
self.control_socket.close() return
# Run initialization # Run initialization
# TODO: read nick/username/etc. from a config # TODO: read nick/username/etc. from a config
@ -97,37 +105,28 @@ class ServerThread(threading.Thread):
self.server_socket.close() self.server_socket.close()
# Tell controller we're quiting # Tell controller we're quiting
self.control_socket.sendall(b'Q' + b'\n') self.logging_channel.send('q')
self.control_socket.close()
# spawn_serverthread(server) → control_socket # spawn_serverthread(server) → control_channel, logging_channel
# Creates a ServerThread for given server and returns the socket for controlling it # Creates a ServerThread for given server and returns the channels for controlling and monitoring it
def spawn_serverthread(server): def spawn_serverthread(server):
thread_control_socket, spawner_control_socket = socket.socketpair() thread_control_socket, spawner_control_socket = socket.socketpair()
ServerThread(server, thread_control_socket).start() control_channel = channel.Channel()
return spawner_control_socket logging_channel = channel.Channel()
ServerThread(server, control_channel, logging_channel).start()
return (control_channel, logging_channel)
if __name__ == '__main__': if __name__ == '__main__':
control_socket = spawn_serverthread(Server('irc.freenode.net', 6667)) control_channel, logging_channel = spawn_serverthread(Server('irc.freenode.net', 6667))
while True: while True:
cmd = input(': ') cmd = input(': ')
if cmd == '': if cmd == '':
control_messages = bytearray() print(logging_channel.recv(blocking = False))
while True:
data = control_socket.recv(1024)
if not data:
break
control_messages.extend(data)
print(control_messages.decode(encoding = 'utf-8', errors = 'replace'))
elif cmd == 'q': elif cmd == 'q':
control_socket.sendall(b'Q\n') control_channel.send('q')
control_socket.close()
break break
else: else:
control_socket.sendall(cmd.encode('utf-8') + b'\n') control_channel.send(cmd)