diff --git a/http2irc.py b/http2irc.py new file mode 100644 index 0000000..dd786b4 --- /dev/null +++ b/http2irc.py @@ -0,0 +1,375 @@ +import aiohttp +import aiohttp.web +import asyncio +import base64 +import collections +import concurrent.futures +import json +import logging +import signal +import ssl +import sys +import toml +import types + + +logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{') + + +SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()} + + +class InvalidConfig(Exception): + '''Error in configuration file''' + + +def _mapping_to_namespace(d): + '''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively''' + return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()}) + + +class Config: + def __init__(self, filename): + self._filename = filename + # Set below: + self.irc = None + self.web = None + self.maps = None + + with open(self._filename, 'r') as fp: + obj = toml.load(fp) + + logging.info(repr(obj)) + + # Sanity checks + if any(x not in ('irc', 'web', 'maps') for x in obj.keys()): + raise InvalidConfig('Unknown sections found in base object') + if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()): + raise InvalidConfig('Invalid section type(s), expected objects/dicts') + if 'irc' in obj: + if any(x not in ('host', 'port', 'ssl', 'nick', 'real') for x in obj['irc']): + raise InvalidConfig('Unknown key found in irc section') + if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname + raise InvalidConfig('Invalid IRC host') + if 'port' in obj['irc'] and (not isinstance(obj['irc']['port'], int) or not 1 <= obj['irc']['port'] <= 65535): + raise InvalidConfig('Invalid IRC port') + if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'): + raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}') + if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname + raise InvalidConfig('Invalid IRC nick') + if 'real' in obj['irc'] and not isinstance(obj['irc']['real'], str): + raise InvalidConfig('Invalid IRC realname') + if 'web' in obj: + if any(x not in ('host', 'port') for x in obj['web']): + raise InvalidConfig('Unknown key found in web section') + if 'host' in obj['web'] and not isinstance(obj['web']['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?) + raise InvalidConfig('Invalid web hostname') + if 'port' in obj['web'] and (not isinstance(obj['web']['port'], int) or not 1 <= obj['web']['port'] <= 65535): + raise InvalidConfig('Invalid web port') + if 'maps' in obj: + for key, map_ in obj['maps'].items(): + # Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace. + #TODO: Support for fancier identifiers (PEP 3131)? + if not isinstance(key, str) or not key or key.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_') != '' or key[0].strip('0123456789') == '': + raise InvalidConfig(f'Invalid map key {key!r}') + if not isinstance(map_, collections.abc.Mapping): + raise InvalidConfig(f'Invalid map for {key!r}') + if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_): + raise InvalidConfig(f'Unknown key(s) found in map {key!r}') + #TODO: Check values + + # Default values + self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.'}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}} + + # Fill in default values for the maps + for key, map_ in obj['maps'].items(): + if 'webpath' not in map_: + map_['webpath'] = f'/{key}' + if 'ircchannel' not in map_: + map_['ircchannel'] = f'#{key}' + if 'auth' not in map_: + map_['auth'] = False + + # Merge in what was read from the config file and convert to SimpleNamespace + for key in ('irc', 'web', 'maps'): + if key in obj: + self._obj[key].update(obj[key]) + setattr(self, key, _mapping_to_namespace(self._obj[key])) + + def __repr__(self): + return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})' + + def reread(self): + return Config(self._filename) + + +class MessageQueue: + # An object holding onto the messages received from nodeping + # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code. + # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that. + # Differences to asyncio.Queue include: + # - No maxsize + # - No put coroutine (not necessary since the queue can never be full) + # - Only one concurrent getter + # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails) + + def __init__(self): + self._getter = None # None | asyncio.Future + self._queue = collections.deque() + + async def get(self): + if self._getter is not None: + raise RuntimeError('Cannot get concurrently') + if len(self._queue) == 0: + self._getter = asyncio.get_running_loop().create_future() + logging.debug('Awaiting getter') + try: + await self._getter + except asyncio.CancelledError: + logging.debug('Cancelled getter') + self._getter = None + raise + logging.debug('Awaited getter') + self._getter = None + # For testing the cancellation/putting back onto the queue + #logging.debug('Delaying message queue get') + #await asyncio.sleep(3) + #logging.debug('Done delaying') + return self.get_nowait() + + def get_nowait(self): + if len(self._queue) == 0: + raise asyncio.QueueEmpty + return self._queue.popleft() + + def put_nowait(self, item): + self._queue.append(item) + if self._getter is not None: + self._getter.set_result(None) + + def putleft_nowait(self, item): + self._queue.appendleft(item) + if self._getter is not None: + self._getter.set_result(None) + + def qsize(self): + return len(self._queue) + + +class IRCClientProtocol(asyncio.Protocol): + def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels): + logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}') + self.messageQueue = messageQueue + self.connectionClosedEvent = connectionClosedEvent + self.loop = loop + self.config = config + self.buffer = b'' + self.connected = False + self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str) + + def connection_made(self, transport): + logging.info('Connected') + self.transport = transport + self.connected = True + nickb = self.config.irc.nick.encode('utf-8') + self.send(b'NICK ' + nickb) + self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8')) + self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long + asyncio.create_task(self.send_messages()) + + def update_channels(self, channels: set): + channelsToPart = self.channels - channels + channelsToJoin = channels - self.channels + self.channels = channels + + if self.connected: + if channelsToPart: + #TODO: Split if too long + self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8')) + if channelsToJoin: + self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8')) + + def send(self, data): + logging.info(f'Send: {data!r}') + self.transport.write(data + b'\r\n') + + async def _get_message(self): + logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}') + messageFuture = asyncio.create_task(self.messageQueue.get()) + done, pending = await asyncio.wait((messageFuture, self.connectionClosedEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + if self.connectionClosedEvent.is_set(): + if messageFuture in pending: + logging.debug('Cancelling messageFuture') + messageFuture.cancel() + try: + await messageFuture + except asyncio.CancelledError: + logging.debug('Cancelled messageFuture') + pass + else: + # messageFuture is already done but we're stopping, so put the result back onto the queue + self.messageQueue.putleft_nowait(messageFuture.result()) + return None, None + assert messageFuture in done, 'Invalid state: messageFuture not in done futures' + return messageFuture.result() + + async def send_messages(self): + while self.connected: + logging.debug(f'{id(self)}: trying to get a message') + channel, message = await self._get_message() + logging.debug(f'{id(self)}: got message: {message!r}') + if message is None: + break + self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8')) + #TODO self.messageQueue.putleft_nowait if delivery fails + await asyncio.sleep(1) # Rate limit + + def data_received(self, data): + logging.debug(f'Data received: {data!r}') + # Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that. + # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer. + # If data does end with CRLF, all messages will have been processed and the buffer will be empty again. + messages = data.split(b'\r\n') + if self.buffer: + self.message_received(self.buffer + messages[0]) + messages = messages[1:] + for message in messages[:-1]: + self.message_received(message) + self.buffer = messages[-1] + + def message_received(self, message): + logging.info(f'Message received: {message!r}') + if message.startswith(b'PING '): + self.send(b'PONG ' + message[5:]) + + def connection_lost(self, exc): + logging.info('The server closed the connection') + self.connected = False + self.connectionClosedEvent.set() + + +class IRCClient: + def __init__(self, messageQueue, config): + self.messageQueue = messageQueue + self.config = config + self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()} + + self._transport = None + self._protocol = None + + def update_config(self, config): + needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl) + self.config = config + if self._transport: # if currently connected: + if needReconnect: + self._transport.close() + else: + self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()} + self._protocol.update_channels(self.channels) + + async def run(self, loop, sigintEvent): + connectionClosedEvent = asyncio.Event() + while True: + connectionClosedEvent.clear() + try: + self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config.irc.host, self.config.irc.port, ssl = SSL_CONTEXTS[self.config.irc.ssl]) + try: + await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + finally: + self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C? + except (ConnectionRefusedError, asyncio.TimeoutError) as e: + logging.error(str(e)) + await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) + if sigintEvent.is_set(): + break + + +class WebServer: + def __init__(self, messageQueue, config): + self.messageQueue = messageQueue + self.config = config + + self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth + + self._app = aiohttp.web.Application() + self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)]) + + self.update_config(config) + + def update_config(self, config): + self._paths = {map_.webpath: (map_.ircchannel, f'Basic {base64.b64encode(map_.auth.encode("utf-8")).decode("utf-8")}' if map_.auth else False) for map_ in config.maps.__dict__.values()} + needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port) + self.config = config + if needRebind: + #TODO + logging.error('Webserver host or port changes while running are currently not supported') + + async def run(self, stopEvent): + runner = aiohttp.web.AppRunner(self._app) + await runner.setup() + site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port) + await site.start() + await stopEvent.wait() + await runner.cleanup() + + async def post(self, request): + logging.info(f'Received request for {request.path!r} with data {await request.read()!r}') + try: + channel, auth = self._paths[request.path] + except KeyError: + raise aiohttp.web.HTTPNotFound() + if auth: + authHeader = request.headers.get('Authorization') + if not authHeader or authHeader != auth: + raise aiohttp.web.HTTPForbidden() + try: + data = await request.json() + except (aiohttp.ContentTypeError, json.JSONDecodeError) as e: + logging.error(f'Invalid data received: {await request.read()!r}') + raise aiohttp.web.HTTPBadRequest() + if 'message' not in data: + logging.error(f'Message missing: {await request.read()!r}') + raise aiohttp.web.HTTPBadRequest() + if '\r' in data['message'] or '\n' in data['message']: + logging.error(f'Linebreaks in message: {await request.read()!r}') + raise aiohttp.web.HTTPBadRequest() + logging.debug(f'Putting message {data["message"]!r} for {channel} into message queue') + self.messageQueue.put_nowait((channel, data['message'])) + raise aiohttp.web.HTTPOk() + + +async def main(): + if len(sys.argv) != 2: + print('Usage: web2irc.py CONFIGFILE', file = sys.stderr) + sys.exit(1) + configFile = sys.argv[1] + config = Config(configFile) + + loop = asyncio.get_running_loop() + + messageQueue = MessageQueue() + + irc = IRCClient(messageQueue, config) + webserver = WebServer(messageQueue, config) + + sigintEvent = asyncio.Event() + def sigint_callback(): + logging.info('Got SIGINT') + nonlocal sigintEvent + sigintEvent.set() + loop.add_signal_handler(signal.SIGINT, sigint_callback) + + def sigusr1_callback(): + logging.info('Got SIGUSR1, reloading config') + nonlocal config, irc, webserver + newConfig = config.reread() + config = newConfig + irc.update_config(config) + webserver.update_config(config) + loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback) + + await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent)) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/nodeping2irc.py b/nodeping2irc.py deleted file mode 100644 index 5d8860d..0000000 --- a/nodeping2irc.py +++ /dev/null @@ -1,246 +0,0 @@ -import aiohttp -import aiohttp.web -import argparse -import asyncio -import collections -import concurrent.futures -import json -import logging -import signal - - -logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{') - - -class MessageQueue: - # An object holding onto the messages received from nodeping - # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code. - # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that. - # Differences to asyncio.Queue include: - # - No maxsize - # - No put coroutine (not necessary since the queue can never be full) - # - Only one concurrent getter - # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails) - - def __init__(self): - self._getter = None # None | asyncio.Future - self._queue = collections.deque() - - async def get(self): - if self._getter is not None: - raise RuntimeError('Cannot get concurrently') - if len(self._queue) == 0: - self._getter = asyncio.get_running_loop().create_future() - logging.debug('Awaiting getter') - try: - await self._getter - except asyncio.CancelledError: - logging.debug('Cancelled getter') - self._getter = None - raise - logging.debug('Awaited getter') - self._getter = None - # For testing the cancellation/putting back onto the queue - #logging.debug('Delaying message queue get') - #await asyncio.sleep(3) - #logging.debug('Done delaying') - return self.get_nowait() - - def get_nowait(self): - if len(self._queue) == 0: - raise asyncio.QueueEmpty - return self._queue.popleft() - - def put_nowait(self, item): - self._queue.append(item) - if self._getter is not None: - self._getter.set_result(None) - - def putleft_nowait(self, item): - self._queue.appendleft(item) - if self._getter is not None: - self._getter.set_result(None) - - def qsize(self): - return len(self._queue) - - -class IRCClientProtocol(asyncio.Protocol): - def __init__(self, messageQueue, stopEvent, loop, nick, real, channel): - logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}') - self.messageQueue = messageQueue - self.stopEvent = stopEvent - self.loop = loop - self.nick = nick - self.real = real - self.channel = channel - self.channelb = channel.encode('utf-8') - self.buffer = b'' - self.connected = False - - def send(self, data): - logging.info(f'Send: {data!r}') - self.transport.write(data + b'\r\n') - - def connection_made(self, transport): - logging.info('Connected') - self.transport = transport - self.connected = True - nickb = self.nick.encode('utf-8') - self.send(b'NICK ' + nickb) - self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.real.encode('utf-8')) - self.send(b'JOIN ' + self.channelb) - asyncio.create_task(self.send_messages()) - - async def _get_message(self): - logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}') - messageFuture = asyncio.create_task(self.messageQueue.get()) - done, pending = await asyncio.wait((messageFuture, self.stopEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) - if self.stopEvent.is_set(): - if messageFuture in pending: - logging.debug('Cancelling messageFuture') - messageFuture.cancel() - try: - await messageFuture - except asyncio.CancelledError: - logging.debug('Cancelled messageFuture') - pass - else: - # messageFuture is already done but we're stopping, so put the result back onto the queue - self.messageQueue.putleft_nowait(messageFuture.result()) - return None - assert messageFuture in done, 'Invalid state: messageFuture not in done futures' - return messageFuture.result() - - async def send_messages(self): - while self.connected: - logging.debug(f'{id(self)}: trying to get a message') - message = await self._get_message() - logging.debug(f'{id(self)}: got message: {message!r}') - if message is None: - break - self.send(b'PRIVMSG ' + self.channelb + b' :' + message.encode('utf-8')) - #TODO self.messageQueue.putleft_nowait if delivery fails - await asyncio.sleep(1) # Rate limit - - def data_received(self, data): - logging.debug(f'Data received: {data!r}') - # Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that. - # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer. - # If data does end with CRLF, all messages will have been processed and the buffer will be empty again. - messages = data.split(b'\r\n') - if self.buffer: - self.message_received(self.buffer + messages[0]) - messages = messages[1:] - for message in messages[:-1]: - self.message_received(message) - self.buffer = messages[-1] - - def message_received(self, message): - logging.info(f'Message received: {message!r}') - if message.startswith(b'PING '): - self.send(b'PONG ' + message[5:]) - - def connection_lost(self, exc): - logging.info('The server closed the connection') - self.connected = False - self.stopEvent.set() - - -class WebServer: - def __init__(self, messageQueue, host, port, auth): - self.messageQueue = messageQueue - self.host = host - self.port = port - self.auth = auth - if auth: - self.authHeader = f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}' - self._app = aiohttp.web.Application() - self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)]) - - async def run(self, stopEvent): - runner = aiohttp.web.AppRunner(self._app) - await runner.setup() - site = aiohttp.web.TCPSite(runner, self.host, self.port) - await site.start() - await stopEvent.wait() - await runner.cleanup() - - async def nodeping_post(self, request): - logging.info(f'Received request with data: {await request.read()!r}') - authHeader = request.headers.get('Authorization') - if self.auth and (not authHeader or authHeader != self.authHeader): - return aiohttp.web.HTTPForbidden() - try: - data = await request.json() - except (aiohttp.ContentTypeError, json.JSONDecodeError) as e: - logging.error(f'Received invalid data: {await request.read()!r}') - return aiohttp.web.HTTPBadRequest() - if 'message' not in data: - logging.error(f'Received invalid data: {await request.read()!r}') - return aiohttp.web.HTTPBadRequest() - if '\r' in data['message'] or '\n' in data['message']: - logging.error(f'Received invalid data: {await request.read()!r}') - return aiohttp.web.HTTPBadRequest() - logging.debug(f'Putting to message queue {id(self.messageQueue)}') - self.messageQueue.put_nowait(data['message']) - return aiohttp.web.HTTPOk() - - -async def run_irc(loop, messageQueue, sigintEvent, host, port, ssl, nick, real, channel): - stopEvent = asyncio.Event() - while True: - stopEvent.clear() - try: - transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop, nick = nick, real = real, channel = channel), host, port, ssl = ssl) - try: - await asyncio.wait((stopEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) - finally: - transport.close() - except (ConnectionRefusedError, asyncio.TimeoutError) as e: - logging.error(str(e)) - await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) - if sigintEvent.is_set(): - break - - -async def run_webserver(loop, messageQueue, sigintEvent, host, port, auth): - server = WebServer(messageQueue, host, port, auth) - await server.run(sigintEvent) - - -def parse_args(): - parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--irchost', type = str, help = 'IRC server hostname', default = 'irc.hackint.org') - parser.add_argument('--ircport', type = int, help = 'IRC server port', default = 6697) - parser.add_argument('--ircssl', choices = ['yes', 'no', 'insecure'], help = 'enable, disable, or use insecure SSL/TLS', default = 'yes') - parser.add_argument('--ircnick', help = 'IRC nickname', default = 'npbot') - parser.add_argument('--ircreal', help = 'IRC realname', default = 'I am a bot.') - parser.add_argument('--ircchannel', help = 'IRC channel to join and post messages', default = '#nodeping') - parser.add_argument('--webhost', type = str, help = 'web server host to bind to', default = '127.0.0.1') - parser.add_argument('--webport', type = int, help = 'web server port to bind to', default = 8080) - parser.add_argument('--webauth', type = str, help = 'basic auth data (user:pass, or None to disable the check)', default = None) - return parser.parse_args() - - -async def main(): - args = parse_args() - ssl = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()}[args.ircssl] - - loop = asyncio.get_running_loop() - - messageQueue = MessageQueue() - sigintEvent = asyncio.Event() - - def sigint_callback(): - logging.info('Got SIGINT') - nonlocal sigintEvent - sigintEvent.set() - loop.add_signal_handler(signal.SIGINT, sigint_callback) - - irc = run_irc(loop, messageQueue, sigintEvent, host = args.irchost, port = args.ircport, ssl = ssl, nick = args.ircnick, real = args.ircreal, channel = args.ircchannel) - webserver = run_webserver(loop, messageQueue, sigintEvent, host = args.webhost, port = args.webport, auth = args.webauth) - await asyncio.gather(irc, webserver) - - -asyncio.run(main())