tea_cah/ircbot.py

210 lines
7.0 KiB
Python

#!/usr/bin/env python3
import select
import socket
import threading
from collections import namedtuple
import channel
from constants import logmessage_types, internal_submessage_types, controlmessage_types
import line_handling
# Immutable description of one IRC server connection: where to connect
# (host, port), how to identify (nick, realname) and which channels to join.
Server = namedtuple('Server', 'host port nick realname channels')
# API(serverthread_object)
# Create a new API object corresponding to given ServerThread object
class API:
    """Thin convenience wrapper around a ServerThread.

    An instance is handed to the line handler so that it can send IRC
    commands and report errors without touching the ServerThread directly.
    All line/recipient/nick/channel arguments are bytes, because they are
    concatenated with bytes literals before being sent.
    """

    def __init__(self, serverthread_object):
        # We need to access the internal functions of the ServerThread object in order to send lines etc.
        self.serverthread_object = serverthread_object

    def send_raw(self, line):
        """Send `line` (bytes) via the ServerThread's send_line_raw()."""
        self.serverthread_object.send_line_raw(line)

    def msg(self, recipient, message):
        """Make sending PRIVMSGs much nicer"""
        line = b'PRIVMSG ' + recipient + b' :' + message
        self.serverthread_object.send_line_raw(line)

    def nick(self, nick):
        """Send a NICK command and update the internal nick tracking state.

        The tracked nick is updated right after sending; the code does not
        wait for any confirmation from the server.
        """
        with self.serverthread_object.nick_lock:
            line = b'NICK ' + nick
            self.serverthread_object.send_line_raw(line)
            self.serverthread_object.nick = nick

    def join(self, channel):
        """Send a JOIN command and update the internal channel tracking state.

        The channel is added to the tracked set right after sending; the code
        does not wait for any confirmation from the server.
        """
        with self.serverthread_object.channels_lock:
            line = b'JOIN ' + channel
            self.serverthread_object.send_line_raw(line)
            self.serverthread_object.channels.add(channel)

    def error(self, message):
        """Report `message` as an internal error on the logging channel."""
        # The message-type enums are imported by name from `constants`;
        # the module itself is not bound, so they must be referenced directly.
        self.serverthread_object.logging_channel.send(
            (logmessage_types.internal, internal_submessage_types.error, message)
        )
# ServerThread(server, control_channel, logging_channel)
# Creates a new server main loop thread
class ServerThread(threading.Thread):
    """Thread that owns the connection to one IRC server.

    server          -- Server tuple (host, port, nick, realname, channels)
    control_channel -- pollable channel this thread reads control messages
                       from (quit / send_line)
    logging_channel -- channel this thread writes sent/received/internal
                       log messages to
    """

    def __init__(self, server, control_channel, logging_channel):
        threading.Thread.__init__(self)

        self.server = server
        self.control_channel = control_channel
        self.logging_channel = logging_channel

        # Serializes writes to the server socket (send_line_raw can be
        # reached both from this thread and through the API object)
        self.server_socket_write_lock = threading.Lock()

        self.nick = None
        self.nick_lock = threading.Lock()

        self.channels = set()
        self.channels_lock = threading.Lock()

    def send_line_raw(self, line):
        """Send one IRC line (bytes, without the trailing CRLF) and log it.

        CR and LF are stripped and the line is cut to 510 bytes before the
        CRLF terminator is appended. PONG replies are sent but not logged.
        """
        # Sanitize line just in case
        line = line.replace(b'\r', b'').replace(b'\n', b'')[:510]

        with self.server_socket_write_lock:
            self.server_socket.sendall(line + b'\r\n')

            # Don't log PONGs
            if not line.startswith(b'PONG '):
                self.logging_channel.send((logmessage_types.sent, line.decode(encoding='utf-8', errors='replace')))

    def handle_line(self, line):
        """Handle one complete line received from the server.

        PINGs are answered directly with a PONG; everything else is logged
        and passed on to line_handling.handle_line().
        """
        command, _, arguments = line.partition(b' ')
        if command.upper() == b'PING':
            self.send_line_raw(b'PONG ' + arguments)
        else:
            # NOTE(review): assumes PINGs are neither logged nor forwarded to
            # the line handler -- confirm against the original indentation
            self.logging_channel.send((logmessage_types.received, line.decode(encoding='utf-8', errors='replace')))
            line_handling.handle_line(line, irc=self.api)

    def mainloop(self):
        """Poll the server socket and the control channel until told to quit.

        Returns when a quit control message is received or when the server
        connection is found to be closed or broken.
        """
        # Register both the server socket and the control channel to our polling object
        poll = select.poll()
        poll.register(self.server_socket, select.POLLIN)
        poll.register(self.control_channel, select.POLLIN)

        server_fd = self.server_socket.fileno()
        control_fd = self.control_channel.fileno()

        # Keep buffer for input
        server_input_buffer = bytearray()

        # TODO: Implement timeouting
        quitting = False
        while not quitting:
            # Wait until we can do something
            for fd, event in poll.poll():
                # Server
                if fd == server_fd:
                    if event & select.POLLIN:
                        # Ready to receive, read into buffer and handle full messages
                        data = self.server_socket.recv(1024)
                    else:
                        # Only POLLIN was requested, so any other event is an
                        # error/hangup condition reported by poll()
                        data = b''

                    if not data:
                        # Empty read means the peer closed the connection;
                        # without this check poll() would keep firing forever
                        error_message = 'Connection to %s:%s was closed' % (self.server.host, self.server.port)
                        self.logging_channel.send((logmessage_types.internal, internal_submessage_types.error, error_message))
                        quitting = True
                        break

                    server_input_buffer.extend(data)

                    # Try to see if we have a full line ending with \r\n in the buffer
                    # If yes, handle it
                    while b'\r\n' in server_input_buffer:
                        # Newline was found, split buffer
                        line, _, server_input_buffer = server_input_buffer.partition(b'\r\n')
                        self.handle_line(line)

                # Control
                elif fd == control_fd:
                    command_type, *arguments = self.control_channel.recv()
                    if command_type == controlmessage_types.quit:
                        quitting = True

                    elif command_type == controlmessage_types.send_line:
                        assert len(arguments) == 1
                        # Uppercase only the command word, leave its parameters alone
                        irc_command, space, irc_arguments = arguments[0].encode('utf-8').partition(b' ')
                        self.send_line_raw(irc_command.upper() + space + irc_arguments)

                    else:
                        error_message = 'Unknown control message: %s' % repr((command_type, *arguments))
                        self.logging_channel.send((logmessage_types.internal, internal_submessage_types.error, error_message))

                else:
                    assert False  # unreachable

    def run(self):
        """Thread entry point: connect, register, run the main loop, clean up.

        A connection failure is reported on the logging channel as an error
        followed by a quit message. A quit message is also sent on normal exit.
        """
        # Connect to given server
        address = (self.server.host, self.server.port)
        try:
            self.server_socket = socket.create_connection(address)
        except OSError:
            # Covers refused connections as well as DNS failures, timeouts etc.
            # Tell controller we failed
            self.logging_channel.send((logmessage_types.internal, internal_submessage_types.error, "Can't connect to %s:%s" % address))
            self.logging_channel.send((logmessage_types.internal, internal_submessage_types.quit))
            return

        # Create an API object to give to outside line handler
        self.api = API(self)

        # Run initialization
        self.send_line_raw(b'USER HynneFlip a a :' + self.server.realname.encode('utf-8'))

        # Set up nick and channels
        self.api.nick(self.server.nick.encode('utf-8'))

        for channel_name in self.server.channels:
            self.api.join(channel_name.encode('utf-8'))

        # Run mainloop
        self.mainloop()

        # Tell the server we're quitting
        try:
            self.send_line_raw(b'QUIT :HynneFlip exiting normally')
        except OSError:
            # Best effort: the connection may already be gone
            pass
        finally:
            self.server_socket.close()

        # Tell controller we're quitting
        self.logging_channel.send((logmessage_types.internal, internal_submessage_types.quit))
# spawn_serverthread(server) → control_channel, logging_channel
# Creates a ServerThread for given server and returns the channels for controlling and monitoring it
def spawn_serverthread(server):
    """Start a ServerThread for `server`.

    Returns a (control_channel, logging_channel) tuple: control messages are
    sent to the thread through the first, and the thread reports sent lines,
    received lines and internal events through the second.
    """
    # No socket pair is created here: the thread is driven entirely through
    # the two Channel objects, and an unused socketpair would leak two fds.
    control_channel = channel.Channel()
    logging_channel = channel.Channel()
    ServerThread(server, control_channel, logging_channel).start()
    return (control_channel, logging_channel)
# Minimal interactive console for manual testing:
#   <empty line>  print all log messages queued so far
#   q             ask the server thread to quit and exit
#   /<line>       send <line> to the server as an IRC command
if __name__ == '__main__':
    server = Server(host = 'irc.freenode.net', port = 6667, nick = 'HynneFlip', realname = 'HynneFlip IRC bot', channels = ['##ingsoc'])
    control_channel, logging_channel = spawn_serverthread(server)

    while True:
        try:
            cmd = input(': ')
        except EOFError:
            # Treat end of input like 'q'; otherwise the main thread would die
            # while the server thread keeps running
            cmd = 'q'

        if cmd == '':
            # Drain the logging channel without blocking
            while True:
                data = logging_channel.recv(blocking = False)
                if data is None:
                    break

                message_type, *message_data = data
                if message_type == logmessage_types.sent:
                    assert len(message_data) == 1
                    print('>' + message_data[0])

                elif message_type == logmessage_types.received:
                    assert len(message_data) == 1
                    print('<' + message_data[0])

                elif message_type == logmessage_types.internal:
                    if message_data[0] == internal_submessage_types.quit:
                        assert len(message_data) == 1
                        print('--- Quit')

                    elif message_data[0] == internal_submessage_types.error:
                        assert len(message_data) == 2
                        print('--- Error', message_data[1])

                    else:
                        print('--- ???', message_data)

                else:
                    print('???', message_type, message_data)

        elif cmd == 'q':
            control_channel.send((controlmessage_types.quit,))
            break

        elif cmd.startswith('/'):
            control_channel.send((controlmessage_types.send_line, cmd[1:]))