|
|
@@ -0,0 +1,375 @@ |
|
|
|
import asyncio
import base64
import collections
import collections.abc
import concurrent.futures
import hmac
import json
import logging
import signal
import ssl
import sys
import types

import aiohttp
import aiohttp.web
import toml
|
|
|
|
|
|
|
|
|
|
|
# Configure the root logger: emit everything from DEBUG upwards, one record per line
# formatted as "<time> <level> <message>" using str.format-style ('{') placeholders.
logging.basicConfig(
    level=logging.DEBUG,
    format='{asctime} {levelname} {message}',
    style='{',
)
|
|
|
|
|
|
|
|
|
|
|
def _make_insecure_ssl_context():
    '''Build a TLS client context that encrypts the connection but verifies neither the certificate nor the hostname'''
    # Calling ssl.SSLContext() without a protocol is deprecated since Python 3.10, so name the protocol explicitly.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be switched off before verify_mode may be lowered to CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


# Maps the value of the `ssl` setting in the irc config section to the object passed as the
# `ssl` argument of loop.create_connection: True (default verified TLS), False (plain text),
# or a context that skips verification.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}
|
|
|
|
|
|
|
|
|
|
|
class InvalidConfig(Exception):
    '''Raised when the configuration file has unknown sections or keys, or values of the wrong kind'''
|
|
|
|
|
|
|
|
|
|
|
def _mapping_to_namespace(d): |
|
|
|
'''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively''' |
|
|
|
return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()}) |
|
|
|
|
|
|
|
|
|
|
|
class Config:
    '''
    Bot configuration read from a TOML file.

    The file may contain the sections `irc`, `web`, and `maps`. After construction, the
    attributes `irc`, `web`, and `maps` hold the validated settings (with defaults filled in)
    as types.SimpleNamespace objects.
    '''

    def __init__(self, filename):
        '''
        Read, validate, and apply defaults to the configuration in `filename`.

        Raises InvalidConfig if the file contains unknown sections/keys or invalid values.
        Errors from open() and toml.load() propagate unchanged.
        '''
        self._filename = filename

        # Set below:
        self.irc = None
        self.web = None
        self.maps = None

        # The TOML specification requires UTF-8, so do not depend on the locale's default encoding.
        with open(self._filename, 'r', encoding='utf-8') as fp:
            obj = toml.load(fp)

        # NOTE: this writes the whole parsed file to the log, including any `auth` values of the maps.
        logging.info(repr(obj))

        self._validate(obj)

        # Default values
        self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.'}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}

        # Fill in default values for the maps. The maps section is optional, so it must not be indexed directly.
        for key, map_ in obj.get('maps', {}).items():
            map_.setdefault('webpath', f'/{key}')
            map_.setdefault('ircchannel', f'#{key}')
            map_.setdefault('auth', False)

        # Merge in what was read from the config file and convert to SimpleNamespace
        for key in ('irc', 'web', 'maps'):
            if key in obj:
                self._obj[key].update(obj[key])
            setattr(self, key, _mapping_to_namespace(self._obj[key]))

    @staticmethod
    def _is_port(value):
        '''Return whether `value` is an integer in the valid TCP port range'''
        # bool is a subclass of int, so a TOML `port = true` has to be rejected explicitly.
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    @classmethod
    def _validate(cls, obj):
        '''Sanity-check the parsed TOML document `obj`; raises InvalidConfig on the first problem found'''
        if any(x not in ('irc', 'web', 'maps') for x in obj):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')
        if 'irc' in obj:
            cls._validate_irc(obj['irc'])
        if 'web' in obj:
            cls._validate_web(obj['web'])
        if 'maps' in obj:
            cls._validate_maps(obj['maps'])

    @classmethod
    def _validate_irc(cls, irc):
        '''Check the keys and value types of the irc section'''
        if any(x not in ('host', 'port', 'ssl', 'nick', 'real') for x in irc):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in irc and not isinstance(irc['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in irc and not cls._is_port(irc['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in irc and irc['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {irc["ssl"]!r}')
        if 'nick' in irc and not isinstance(irc['nick'], str): #TODO: Check whether it's a valid nickname
            raise InvalidConfig('Invalid IRC nick')
        if 'real' in irc and not isinstance(irc['real'], str):
            raise InvalidConfig('Invalid IRC realname')

    @classmethod
    def _validate_web(cls, web):
        '''Check the keys and value types of the web section'''
        if any(x not in ('host', 'port') for x in web):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in web and not isinstance(web['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in web and not cls._is_port(web['port']):
            raise InvalidConfig('Invalid web port')

    @staticmethod
    def _validate_maps(maps):
        '''Check the name, keys, and value types of every entry of the maps section'''
        for key, map_ in maps.items():
            # Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace.
            # Only ASCII letters, digits, and underscores are accepted, and the first character must not be a digit.
            #TODO: Support for fancier identifiers (PEP 3131)?
            if not isinstance(key, str) or not key.isascii() or not key.isidentifier():
                raise InvalidConfig(f'Invalid map key {key!r}')
            if not isinstance(map_, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid map for {key!r}')
            if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_):
                raise InvalidConfig(f'Unknown key(s) found in map {key!r}')
            # The values are used as str by the web server and the IRC client (dict key, str.join, str.encode),
            # so reject other types here instead of failing later at runtime.
            if 'webpath' in map_ and not isinstance(map_['webpath'], str):
                raise InvalidConfig(f'Invalid webpath in map {key!r}')
            if 'ircchannel' in map_ and not isinstance(map_['ircchannel'], str):
                raise InvalidConfig(f'Invalid ircchannel in map {key!r}')
            # auth is either False (no authentication) or the credentials string.
            if 'auth' in map_ and map_['auth'] is not False and not isinstance(map_['auth'], str):
                raise InvalidConfig(f'Invalid auth in map {key!r}')
            #TODO: Check the contents of the values (e.g. valid channel names)

    def __repr__(self):
        return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})'

    def reread(self):
        '''Return a new Config freshly loaded from the same file; this object is left untouched'''
        return Config(self._filename)
|
|
|
|
|
|
|
|
|
|
|
class MessageQueue:
    '''Unbounded FIFO of pending messages with at most one waiting consumer'''

    # An object holding onto the messages received from nodeping
    # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    # Differences to asyncio.Queue include:
    # - No maxsize
    # - No put coroutine (not necessary since the queue can never be full)
    # - Only one concurrent getter
    # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

    def __init__(self):
        self._getter = None # None | asyncio.Future; set while a get() call is waiting for an item
        self._queue = collections.deque()

    async def get(self):
        '''
        Remove and return the item at the front of the queue, waiting until one is available.

        Raises RuntimeError if another get() call is already waiting.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        # Loop rather than wait once: the item that triggered the wake-up may have been taken
        # through get_nowait() before this coroutine got to run again.
        while not self._queue:
            self._getter = asyncio.get_running_loop().create_future()
            logging.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                logging.debug('Cancelled getter')
                raise
            finally:
                # Always release the getter slot so that a later get() is possible again.
                self._getter = None
            logging.debug('Awaited getter')
        return self.get_nowait()

    def get_nowait(self):
        '''Remove and return the item at the front of the queue; raises asyncio.QueueEmpty if there is none'''
        if not self._queue:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        '''Append `item` to the back of the queue and wake the waiting getter, if any'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, item):
        '''Insert `item` at the front of the queue and wake the waiting getter, if any'''
        self._queue.appendleft(item)
        self._wake_getter()

    def _wake_getter(self):
        '''Resolve the getter future unless there is none or it has already been resolved or cancelled'''
        # The future stays assigned until get() resumes, so a second put in the meantime (or a put after
        # the getter was cancelled) must not call set_result again; that would raise InvalidStateError.
        if self._getter is not None and not self._getter.done():
            self._getter.set_result(None)

    def qsize(self):
        '''Return the number of items currently in the queue'''
        return len(self._queue)
|
|
|
|
|
|
|
|
|
|
|
class IRCClientProtocol(asyncio.Protocol):
    '''
    asyncio protocol for one IRC connection.

    Registers with the server, joins the configured channels, answers PINGs, and sends the
    (channel, message) pairs taken from the message queue as PRIVMSGs until the connection is lost.
    '''

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        '''
        messageQueue: source of (channel, message) pairs; needs get, putleft_nowait, and qsize
        connectionClosedEvent: asyncio.Event that is set when the connection is lost
        loop: the event loop (stored but not otherwise used by this class)
        config: configuration object; config.irc.nick and config.irc.real are read on connect
        channels: set of channel names to join
        '''
        logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}')
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b'' # Received bytes that do not yet form a complete line
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.transport = None # Set in connection_made
        self._sendTask = None # Task running send_messages; set in connection_made

    def connection_made(self, transport):
        '''Register with the server, join the channels, and start the sender task'''
        logging.info('Connected')
        self.transport = transport
        self.connected = True
        nickb = self.config.irc.nick.encode('utf-8')
        self.send(b'NICK ' + nickb)
        self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8'))
        if self.channels: # A JOIN without any channel would be malformed
            self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
        # Keep a reference to the task: the event loop itself only holds weak references to tasks.
        self._sendTask = asyncio.create_task(self.send_messages())

    def update_channels(self, channels: set):
        '''Replace the channel set; if connected, part the channels no longer wanted and join the new ones'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                #TODO: Split if too long
                self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
            if channelsToJoin:
                self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))

    def send(self, data):
        '''Write one line (bytes without the line terminator) to the server'''
        logging.info(f'Send: {data!r}')
        self.transport.write(data + b'\r\n')

    async def _get_message(self):
        '''
        Wait for the next message from the queue or for the connection to close, whichever happens first.

        Returns the (channel, message) pair from the queue, or (None, None) if the connection was closed.
        In the latter case, no message is lost: a message that was already taken is put back to the front.
        '''
        logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        # asyncio.wait no longer accepts bare coroutines (removed in Python 3.11), so wrap both explicitly.
        messageFuture = asyncio.ensure_future(self.messageQueue.get())
        closedFuture = asyncio.ensure_future(self.connectionClosedEvent.wait())
        try:
            await asyncio.wait((messageFuture, closedFuture), return_when = asyncio.FIRST_COMPLETED)
        finally:
            # Without this, one task waiting on the event would pile up for every message sent.
            closedFuture.cancel()
        if self.connectionClosedEvent.is_set():
            if not messageFuture.done():
                logging.debug('Cancelling messageFuture')
                messageFuture.cancel()
                try:
                    await messageFuture
                except asyncio.CancelledError:
                    logging.debug('Cancelled messageFuture')
            else:
                # messageFuture is already done but we're stopping, so put the result back onto the queue
                self.messageQueue.putleft_nowait(messageFuture.result())
            return None, None
        # The event is not set, so it must have been the message that completed the wait.
        return messageFuture.result()

    async def send_messages(self):
        '''Forward messages from the queue to IRC, at most one per second, until the connection is gone'''
        while self.connected:
            logging.debug(f'{id(self)}: trying to get a message')
            channel, message = await self._get_message()
            logging.debug(f'{id(self)}: got message: {message!r}')
            if message is None:
                break
            self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
            #TODO self.messageQueue.putleft_nowait if delivery fails
            await asyncio.sleep(1) # Rate limit

    def data_received(self, data):
        '''Split the incoming byte stream into CRLF-terminated lines and hand each complete line to message_received'''
        logging.debug(f'Data received: {data!r}')
        # Append to whatever is left over from earlier calls before splitting, so that lines (and even the
        # CRLF itself) spanning several chunks are reassembled correctly. Everything except the last piece
        # is a complete line; the last piece (empty if data ended on a CRLF) stays in the buffer.
        self.buffer += data
        *messages, self.buffer = self.buffer.split(b'\r\n')
        for message in messages:
            self.message_received(message)

    def message_received(self, message):
        '''Handle one complete line from the server; only PING is acted upon'''
        logging.info(f'Message received: {message!r}')
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

    def connection_lost(self, exc):
        '''Mark the connection as closed and signal this through connectionClosedEvent'''
        logging.info('The server closed the connection')
        self.connected = False
        self.connectionClosedEvent.set()
|
|
|
|
|
|
|
|
|
|
|
class IRCClient:
    '''Keeps an IRC connection alive (reconnecting as needed) and applies configuration changes to it'''

    def __init__(self, messageQueue, config):
        '''
        messageQueue: queue of (channel, message) pairs handed to each connection's protocol
        config: configuration object; config.irc (host, port, ssl) and config.maps are read
        '''
        self.messageQueue = messageQueue
        self.config = config
        self.channels = self._channels_from_config(config)

        self._transport = None # Transport of the current connection; None while not connected
        self._protocol = None # Protocol of the current connection; None while not connected

    @staticmethod
    def _channels_from_config(config):
        '''Return the set of IRC channels named by the maps of `config`'''
        return {map_.ircchannel for map_ in vars(config.maps).values()}

    @staticmethod
    async def _wait_for_first(*awaitables):
        '''Wait until the first of `awaitables` completes, then cancel the remaining ones'''
        # asyncio.wait no longer accepts bare coroutines (removed in Python 3.11), so wrap them explicitly.
        tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            # Cancelling a finished task is a no-op; the others must not be left pending forever.
            for task in tasks:
                task.cancel()

    def update_config(self, config):
        '''
        Switch to `config`.

        If host, port, or ssl changed, the current connection is closed so that run() reconnects
        with the new settings; otherwise the channel changes are applied to the live connection.
        '''
        needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl)
        self.config = config
        # Always refresh the channel set: the next connection made by run() is created with
        # self.channels, so it must be current even when reconnecting or not connected right now.
        self.channels = self._channels_from_config(config)
        if self._transport: # if currently connected:
            if needReconnect:
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    async def run(self, loop, sigintEvent):
        '''
        Connect to the IRC server and reconnect whenever the connection is lost, until sigintEvent is set.

        loop: event loop used to create the connection
        sigintEvent: asyncio.Event; once set, the connection is closed and this coroutine returns
        '''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(
                    lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels),
                    self.config.irc.host,
                    self.config.irc.port,
                    ssl = SSL_CONTEXTS[self.config.irc.ssl],
                )
                try:
                    await self._wait_for_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
                    self._transport = None
                    self._protocol = None
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers refused connections as well as DNS and TLS failures; retry after
                # a pause instead of letting a transient network problem end the client.
                logging.error(str(e))
                await self._wait_for_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break
|
|
|
|
|
|
|
|
|
|
|
class WebServer:
    '''HTTP endpoint that accepts POSTed JSON messages and queues them for delivery to IRC'''

    def __init__(self, messageQueue, config):
        '''
        messageQueue: queue that accepted (channel, message) pairs are put into
        config: configuration object; config.web (host, port) and config.maps are read
        '''
        self.messageQueue = messageQueue
        self.config = config

        self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        self.update_config(config)

    @staticmethod
    def _basic_auth_header(credentials):
        '''Return the Authorization header value for HTTP basic auth with `credentials` (presumably 'user:password'; the string is encoded as is)'''
        return 'Basic ' + base64.b64encode(credentials.encode('utf-8')).decode('utf-8')

    @staticmethod
    def _is_authorised(expected, provided):
        '''Return whether the request's Authorization header value `provided` (str or None) equals `expected`'''
        if not provided:
            return False
        # Compare in constant time so that the response time does not reveal how much of the value matched.
        # compare_digest only accepts ASCII str, hence the comparison on bytes; surrogatepass makes the
        # encoding succeed for any str, including undecodable header bytes represented as lone surrogates.
        return hmac.compare_digest(expected.encode('utf-8'), provided.encode('utf-8', 'surrogatepass'))

    @staticmethod
    def _extract_message(data):
        '''
        Return the message text from the decoded JSON request body `data`.

        Raises ValueError (with the reason as its text) unless `data` is an object whose 'message'
        entry is a string without line breaks.
        '''
        if not isinstance(data, dict) or 'message' not in data:
            raise ValueError('Message missing')
        message = data['message']
        # Anything but str would break the IRC sender later on, which calls message.encode().
        if not isinstance(message, str):
            raise ValueError('Message is not a string')
        # A line break would end the IRC line early and let the rest be interpreted as a separate command.
        if '\r' in message or '\n' in message:
            raise ValueError('Linebreaks in message')
        return message

    def update_config(self, config):
        '''Rebuild the path table from `config`; a changed web host or port is only reported, not applied'''
        self._paths = {
            map_.webpath: (map_.ircchannel, self._basic_auth_header(map_.auth) if map_.auth else False)
            for map_ in vars(config.maps).values()
        }
        needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port)
        self.config = config
        if needRebind:
            #TODO
            logging.error('Webserver host or port changes while running are currently not supported')

    async def run(self, stopEvent):
        '''Serve on the configured host and port until `stopEvent` is set, then shut the server down'''
        runner = aiohttp.web.AppRunner(self._app)
        await runner.setup()
        try:
            site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port)
            await site.start()
            await stopEvent.wait()
        finally:
            # Also clean up if binding fails or this coroutine is cancelled.
            await runner.cleanup()

    async def post(self, request):
        '''
        Handle a POST request: look up the path, check the authentication, validate the JSON body, and queue the message.

        Always finishes by raising an aiohttp HTTP exception: HTTPOk on success, HTTPNotFound for an
        unknown path, HTTPForbidden for missing/wrong credentials, HTTPBadRequest for an invalid body.
        '''
        logging.info(f'Received request for {request.path!r} with data {await request.read()!r}')
        try:
            channel, auth = self._paths[request.path]
        except KeyError:
            raise aiohttp.web.HTTPNotFound()
        if auth and not self._is_authorised(auth, request.headers.get('Authorization')):
            raise aiohttp.web.HTTPForbidden()
        try:
            data = await request.json()
        except (aiohttp.ContentTypeError, ValueError):
            # ValueError covers both malformed JSON (json.JSONDecodeError) and undecodable bytes (UnicodeDecodeError).
            logging.error(f'Invalid data received: {await request.read()!r}')
            raise aiohttp.web.HTTPBadRequest()
        try:
            message = self._extract_message(data)
        except ValueError as e:
            logging.error(f'{e}: {await request.read()!r}')
            raise aiohttp.web.HTTPBadRequest()
        logging.debug(f'Putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message))
        raise aiohttp.web.HTTPOk()
|
|
|
|
|
|
|
|
|
|
|
async def main():
    '''
    Entry point: load the config file named on the command line and run the IRC client and the web server.

    SIGINT stops both; SIGUSR1 reloads the config file and applies it to both.
    Exits with status 1 if the command line does not consist of exactly one argument.
    '''
    if len(sys.argv) != 2:
        print('Usage: web2irc.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    configFile = sys.argv[1]
    config = Config(configFile)

    loop = asyncio.get_running_loop()

    messageQueue = MessageQueue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(messageQueue, config)

    sigintEvent = asyncio.Event()

    def sigint_callback():
        logging.info('Got SIGINT')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    def sigusr1_callback():
        nonlocal config # Rebound below; irc and webserver are only read
        logging.info('Got SIGUSR1, reloading config')
        try:
            newConfig = config.reread()
        except (InvalidConfig, OSError, ValueError) as e:
            # A broken or unreadable file must not take effect; keep running with the current config.
            # ValueError is included because the toml package's decode error derives from it.
            logging.error(f'Reloading config failed, keeping the current configuration: {e!s}')
            return
        config = newConfig
        irc.update_config(config)
        webserver.update_config(config)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))
|
|
|
|
|
|
|
|
|
|
|
# Only start the bot when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    asyncio.run(main())