|
import asyncio
import base64
import collections
import collections.abc
import concurrent.futures
import hmac
import logging
import os.path
import signal
import ssl
import string
import sys

import aiohttp
import aiohttp.web
import toml
-
-
# Maps the value of the `ssl` key in the irc config section to the `ssl` argument passed to loop.create_connection:
# True and False are passed through as is; 'insecure' uses a bare SSLContext instead of the default one.
SSL_CONTEXTS = dict(yes = True, no = False, insecure = ssl.SSLContext())
-
-
class InvalidConfig(Exception):
    '''Raised when the configuration file fails validation.'''
-
-
def is_valid_pem(path, withCert):
    '''Very basic check whether a file looks like a PEM private key, optionally preceded by a certificate.

    The expected layout is one CERTIFICATE block immediately followed by one PRIVATE KEY block (if withCert)
    or exactly one PRIVATE KEY block (otherwise), with nothing after the key's END line. Only the BEGIN/END
    lines and the base64 syntax of the bodies are checked, not the encoded data.

    path: path of the file to check
    withCert: whether the file has to start with a certificate block
    Returns True if the file could be read and has the expected layout, False otherwise.
    '''
    try:
        with open(path, 'rb') as fp:
            contents = fp.read()
    except (OSError, TypeError, ValueError):
        # Missing/unreadable file, or a path that open() rejects outright
        return False

    sections = []
    if withCert:
        sections.append((b'-----BEGIN CERTIFICATE-----\n', b'-----END CERTIFICATE-----\n'))
    sections.append((b'-----BEGIN PRIVATE KEY-----\n', b'-----END PRIVATE KEY-----\n'))

    # Explicit checks rather than assert statements so that the validation is not stripped under `python -O`.
    # Each iteration consumes one block from the front of `contents`.
    for begin, end in sections:
        if not contents.startswith(begin):
            return False
        endPos = contents.find(end)
        if endPos == -1:
            return False
        try:
            base64.b64decode(contents[len(begin):endPos].replace(b'\n', b''), validate = True)
        except ValueError: # binascii.Error is a subclass of ValueError
            return False
        contents = contents[endPos + len(end):]

    # Nothing may follow the private key block
    return contents == b''
-
-
class Config(dict):
    '''Validated configuration loaded from a TOML file.

    After construction, the instance is a dict with exactly the keys 'logging', 'irc', 'web', and 'maps'. Each
    holds the built-in defaults overridden by the values from the file, and every entry of 'maps' has the keys
    'webpath', 'ircchannel', and 'auth'.
    '''

    _SECTIONS = ('logging', 'irc', 'web', 'maps')

    def __init__(self, filename):
        '''Read, validate, and apply defaults to the configuration in `filename`.

        Raises InvalidConfig if the contents fail validation. Errors from opening or parsing the file
        propagate unchanged.
        '''
        super().__init__()
        self._filename = filename

        with open(self._filename, 'r') as fp:
            obj = toml.load(fp)

        logging.info(repr(obj))

        # Sanity checks
        if any(x not in self._SECTIONS for x in obj):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')
        if 'logging' in obj:
            self._check_logging(obj['logging'])
        if 'irc' in obj:
            self._check_irc(obj['irc'])
        if 'web' in obj:
            self._check_web(obj['web'])
        if 'maps' in obj:
            self._check_maps(obj['maps'])

        # Default values
        finalObj = {
            'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {message}'},
            'irc': {
                'host': 'irc.hackint.org',
                'port': 6697,
                'ssl': 'yes',
                'nick': 'h2ibot',
                'real': 'I am an http2irc bot.',
                'certfile': None,
                'certkeyfile': None,
            },
            'web': {'host': '127.0.0.1', 'port': 8080},
            'maps': {},
        }

        # Fill in default values for the maps; the maps section is optional, so don't assume it exists
        for key, map_ in obj.get('maps', {}).items():
            map_.setdefault('webpath', f'/{key}')
            map_.setdefault('ircchannel', f'#{key}')
            map_.setdefault('auth', False)

        # Merge in what was read from the config file and set keys on self
        for key in self._SECTIONS:
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    @staticmethod
    def _is_valid_port(value):
        '''Return whether value is an integer in the TCP port range; bools are rejected although they are ints.'''
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    def _check_logging(self, section):
        '''Validate the logging section; raises InvalidConfig on the first problem found.'''
        if any(x not in ('level', 'format') for x in section):
            raise InvalidConfig('Unknown key found in log section')
        if 'level' in section and section['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
            raise InvalidConfig('Invalid log level')
        if 'format' in section:
            if not isinstance(section['format'], str):
                raise InvalidConfig('Invalid log format')
            #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
            try:
                # This counts the number of replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                fieldCount = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
            except ValueError as e:
                raise InvalidConfig('Invalid log format: parsing failed') from e
            # Explicit check instead of an assert so that it is not stripped under `python -O`
            if fieldCount == 0:
                raise InvalidConfig('Invalid log format: parsing failed')

    def _check_irc(self, section):
        '''Validate the irc section; raises InvalidConfig on the first problem found.'''
        if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')
        if 'nick' in section and not isinstance(section['nick'], str): #TODO: Check whether it's a valid nickname
            raise InvalidConfig('Invalid IRC nick')
        if 'real' in section and not isinstance(section['real'], str):
            raise InvalidConfig('Invalid IRC realname')
        if ('certfile' in section) != ('certkeyfile' in section):
            raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
        if 'certfile' in section:
            if not isinstance(section['certfile'], str):
                raise InvalidConfig('Invalid certificate file: not a string')
            if not os.path.isfile(section['certfile']):
                raise InvalidConfig('Invalid certificate file: not a regular file')
            if not is_valid_pem(section['certfile'], True):
                raise InvalidConfig('Invalid certificate file: not a valid PEM cert')
        if 'certkeyfile' in section:
            if not isinstance(section['certkeyfile'], str):
                raise InvalidConfig('Invalid certificate key file: not a string')
            if not os.path.isfile(section['certkeyfile']):
                raise InvalidConfig('Invalid certificate key file: not a regular file')
            if not is_valid_pem(section['certkeyfile'], False):
                raise InvalidConfig('Invalid certificate key file: not a valid PEM key')

    def _check_web(self, section):
        '''Validate the web section; raises InvalidConfig on the first problem found.'''
        if any(x not in ('host', 'port') for x in section):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid web port')

    def _check_maps(self, section):
        '''Validate the maps section; raises InvalidConfig on the first problem found.'''
        for key, map_ in section.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid map key {key!r}')
            if not isinstance(map_, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid map for {key!r}')
            if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_):
                raise InvalidConfig(f'Unknown key(s) found in map {key!r}')
            # The channel is joined with str.join and a truthy auth value gets .encode()d later on,
            # so reject values that could only blow up at that point.
            if 'ircchannel' in map_ and not isinstance(map_['ircchannel'], str):
                raise InvalidConfig(f'Invalid ircchannel in map {key!r}')
            if 'auth' in map_ and map_['auth'] and not isinstance(map_['auth'], str):
                raise InvalidConfig(f'Invalid auth in map {key!r}')
            #TODO: Check webpath and the syntax of the channel name

    def __repr__(self):
        return f'<Config(logging={self["logging"]!r}, irc={self["irc"]!r}, web={self["web"]!r}, maps={self["maps"]!r})>'

    def reread(self):
        '''Load the same file again and return the result as a new Config; this instance is left unchanged.'''
        return Config(self._filename)
-
-
class MessageQueue:
    '''Queue of messages waiting for delivery, with a single asynchronous consumer.'''
    # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    # Differences to asyncio.Queue include:
    # - No maxsize
    # - No put coroutine (not necessary since the queue can never be full)
    # - Only one concurrent getter
    # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

    def __init__(self):
        self._getter = None # None | asyncio.Future
        self._queue = collections.deque()

    async def get(self):
        '''Remove and return the item at the front of the queue, waiting until one is available.

        Raises RuntimeError if another get() call is already waiting.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        # Loop rather than wait once so that a wakeup that finds the queue empty again simply goes back to waiting
        while not self._queue:
            self._getter = asyncio.get_running_loop().create_future()
            logging.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                logging.debug('Cancelled getter')
                raise
            finally:
                self._getter = None
            logging.debug('Awaited getter')
        return self.get_nowait()

    def get_nowait(self):
        '''Remove and return the item at the front of the queue; raises asyncio.QueueEmpty if there is none.'''
        if not self._queue:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        '''Append item to the back of the queue.'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, *item):
        '''Put the given items at the front of the queue, keeping their order.'''
        self._queue.extendleft(reversed(item))
        self._wake_getter()

    def _wake_getter(self):
        '''Wake up a waiting get() call if there is something to fetch.'''
        # The future stays in _getter until the waiting coroutine actually resumes, so several puts can happen
        # while it is already resolved. done() covers both that and cancellation; calling set_result on a done
        # future would raise InvalidStateError.
        if self._queue and self._getter is not None and not self._getter.done():
            self._getter.set_result(None)

    def qsize(self):
        '''Return the number of items currently in the queue.'''
        return len(self._queue)
-
-
class IRCClientProtocol(asyncio.Protocol):
    '''asyncio protocol handling one IRC connection.

    Registers with the server, joins the configured channels, sends the messages from the message queue as
    PRIVMSGs, and periodically verifies with a PING that the connection is alive. Messages whose delivery
    could not be verified are put back at the front of the queue.
    '''

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}')
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b''
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.unconfirmedMessages = []
        self.pongReceivedEvent = asyncio.Event()
        self._tasks = [] # Strong references to the background tasks so they can't be garbage-collected while running

    def connection_made(self, transport):
        logging.info('Connected')
        self.transport = transport
        self.connected = True
        nickb = self.config['irc']['nick'].encode('utf-8')
        self.send(b'NICK ' + nickb)
        self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config['irc']['real'].encode('utf-8'))

    def update_channels(self, channels: set):
        '''Replace the set of channels, parting and joining as necessary if currently connected.'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                #TODO: Split if too long
                self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
            if channelsToJoin:
                self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))

    def send(self, data):
        '''Send one line (bytes without the trailing CRLF) to the server.'''
        logging.info(f'Send: {data!r}')
        self.transport.write(data + b'\r\n')

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of the awaitables completes, then cancel the others.'''
        # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap them explicitly.
        tasks = [asyncio.ensure_future(x) for x in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    def _requeue_unconfirmed(self):
        '''Put all unconfirmed messages back at the front of the message queue.'''
        if self.unconfirmedMessages:
            self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
            self.unconfirmedMessages = []

    async def _get_message(self):
        '''Return the next (channel, message) from the queue, or (None, None) if the connection was closed.'''
        logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        messageFuture = asyncio.ensure_future(self.messageQueue.get())
        closedFuture = asyncio.ensure_future(self.connectionClosedEvent.wait())
        try:
            done, pending = await asyncio.wait((messageFuture, closedFuture), return_when = asyncio.FIRST_COMPLETED)
        except asyncio.CancelledError:
            messageFuture.cancel()
            raise
        finally:
            closedFuture.cancel()
        if self.connectionClosedEvent.is_set():
            if messageFuture in pending:
                logging.debug('Cancelling messageFuture')
                messageFuture.cancel()
                try:
                    await messageFuture
                except asyncio.CancelledError:
                    logging.debug('Cancelled messageFuture')
            else:
                # messageFuture is already done but we're stopping, so put the result back onto the queue
                self.messageQueue.putleft_nowait(messageFuture.result())
            return None, None
        assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
        return messageFuture.result()

    async def send_messages(self):
        '''Send messages from the queue, one per second, until the connection is closed.'''
        while self.connected:
            logging.debug(f'{id(self)}: trying to get a message')
            channel, message = await self._get_message()
            logging.debug(f'{id(self)}: got message: {message!r}')
            if message is None:
                break
            #TODO Split if the message is too long.
            self.unconfirmedMessages.append((channel, message))
            self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
            await asyncio.sleep(1) # Rate limit

    async def confirm_messages(self):
        '''Once per minute, confirm the delivery of sent messages with a PING; requeue them if that fails.'''
        while self.connected:
            await self._wait_first(asyncio.sleep(60), self.connectionClosedEvent.wait()) # Confirm once per minute
            if not self.connected:
                break
            if not self.unconfirmedMessages:
                logging.debug(f'{id(self)}: no messages to confirm')
                continue
            logging.debug(f'{id(self)}: trying to confirm message delivery')
            # Only messages sent before the PING are covered by the PONG; anything sent while waiting for it stays unconfirmed.
            confirmCount = len(self.unconfirmedMessages)
            self.pongReceivedEvent.clear()
            self.send(b'PING :42')
            await self._wait_first(asyncio.sleep(5), self.pongReceivedEvent.wait())
            logging.debug(f'{id(self)}: message delivery success: {self.pongReceivedEvent.is_set()}')
            if self.pongReceivedEvent.is_set():
                del self.unconfirmedMessages[:confirmCount]
            else:
                # No PONG received in five seconds, assume connection's dead
                self._requeue_unconfirmed()
                self.transport.close()
        # Disconnected, can't confirm unconfirmed messages, requeue them directly
        self._requeue_unconfirmed()

    def data_received(self, data):
        logging.debug(f'Data received: {data!r}')
        # Prepend whatever is left in the buffer and split on CRLF. The last piece is either an incomplete
        # message or empty (if the data ended with CRLF), so it becomes the new buffer; everything before it
        # is a complete message. Joining before splitting also handles chunks without any CRLF and a CRLF
        # that is split across two chunks.
        messages = (self.buffer + data).split(b'\r\n')
        self.buffer = messages.pop()
        for message in messages:
            self.message_received(message)

    def message_received(self, message):
        '''Handle one line received from the server (bytes without the trailing CRLF).'''
        logging.info(f'Message received: {message!r}')
        if message.startswith(b':'):
            # Prefixed message, extract command + parameters (the prefix cannot contain a space)
            # partition leaves an empty command if the line consists of nothing but a prefix
            message = message.partition(b' ')[2]
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])
        elif message.startswith(b'PONG '):
            self.pongReceivedEvent.set()
        elif message.startswith(b'001 '):
            # Connection registered
            if self.channels:
                self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
            self._tasks = [asyncio.ensure_future(self.send_messages()), asyncio.ensure_future(self.confirm_messages())]

    def connection_lost(self, exc):
        logging.info('The server closed the connection')
        self.connected = False
        self.connectionClosedEvent.set()
-
-
class IRCClient:
    '''Keeps an IRC connection open, reconnecting whenever it is closed, until asked to stop.'''

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}

        self._transport = None
        self._protocol = None

    def update_config(self, config):
        '''Apply a new configuration, reconnecting if the irc section changed and updating the channels otherwise.'''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        # Always refresh the channels: the next connection is created from self.channels, also after a reconnect or while disconnected.
        self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
        if self._transport: # if currently connected:
            if needReconnect:
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Return the `ssl` argument for create_connection according to the irc config section.'''
        ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
        ircConfig = self.config['irc']
        if ircConfig['certfile'] and ircConfig['certkeyfile']:
            if ctx is True:
                # A client certificate can only be loaded into an actual context object
                ctx = ssl.create_default_context()
            if isinstance(ctx, ssl.SSLContext):
                ctx.load_cert_chain(ircConfig['certfile'], keyfile = ircConfig['certkeyfile'])
        return ctx

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of the awaitables completes, then cancel the others.'''
        # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap them explicitly.
        tasks = [asyncio.ensure_future(x) for x in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    async def run(self, loop, sigintEvent):
        '''Connect and reconnect until sigintEvent is set.

        loop: event loop used to create the connection
        sigintEvent: asyncio.Event that signals that the client should shut down
        '''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(
                    lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels),
                    self.config['irc']['host'],
                    self.config['irc']['port'],
                    ssl = self._get_ssl_context(),
                )
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
                    # Forget the dead connection so that update_config doesn't treat it as a live one
                    self._transport = None
                    self._protocol = None
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers refused connections as well as DNS and TLS failures; retry after a short delay
                logging.error(str(e))
                await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break
-
-
class WebServer:
    '''HTTP server that accepts POST requests on the configured paths and queues their payload as IRC messages.'''

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config

        self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth
        # Must exist before update_config is called since that may set it
        self._configChanged = asyncio.Event()

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        self.update_config(config)

    def update_config(self, config):
        '''Apply a new configuration; triggers a rebind in run() if the web section changed.'''
        paths = {}
        for map_ in config['maps'].values():
            if map_['auth']:
                auth = 'Basic ' + base64.b64encode(map_['auth'].encode('utf-8')).decode('utf-8')
            else:
                auth = False
            paths[map_['webpath']] = (map_['ircchannel'], auth)
        self._paths = paths
        needRebind = self.config['web'] != config['web']
        self.config = config
        if needRebind:
            self._configChanged.set()

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of the awaitables completes, then cancel the others.'''
        # asyncio.wait no longer accepts bare coroutines (Python 3.11+), so wrap them explicitly.
        tasks = [asyncio.ensure_future(x) for x in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    async def run(self, stopEvent):
        '''Serve requests until stopEvent is set, rebinding whenever the web section of the config changes.'''
        while True:
            runner = aiohttp.web.AppRunner(self._app)
            await runner.setup()
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                await self._wait_first(stopEvent.wait(), self._configChanged.wait())
            finally:
                # Also clean up if binding failed
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

    async def post(self, request):
        '''Handle a POST request: check path and authentication, then queue the payload for the mapped channel.

        Always raises an aiohttp.web HTTP exception: HTTPOk on success, HTTPNotFound for unknown paths,
        HTTPForbidden for failed authentication, and HTTPBadRequest for unusable payloads.
        '''
        logging.info(f'Received request for {request.path!r}')
        try:
            channel, auth = self._paths[request.path]
        except KeyError:
            logging.info(f'Bad request: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            # Constant-time comparison on bytes (compare_digest rejects non-ASCII str); neither the supplied
            # nor the expected credentials are written to the log.
            if not authHeader or not hmac.compare_digest(authHeader.encode('utf-8', 'surrogateescape'), auth.encode('utf-8')):
                logging.info('Bad request: authentication failed')
                raise aiohttp.web.HTTPForbidden()
        try:
            message = await request.text()
        except Exception as e:
            logging.info(f'Bad request: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        logging.debug(f'Request payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            logging.info('Bad request: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        logging.debug(f'Putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message))
        raise aiohttp.web.HTTPOk()
-
-
def configure_logging(config):
    '''(Re)configure the root logger from the logging section of config.

    Sets the level and replaces all handlers of the root logger with a single stderr handler that uses the
    configured '{'-style format.
    '''
    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    root = logging.getLogger()
    root.setLevel(getattr(logging, config['logging']['level']))
    # Remove the old handlers through the documented API and close them instead of just dropping the list
    for oldHandler in list(root.handlers):
        root.removeHandler(oldHandler)
        oldHandler.close()
    stderrHandler = logging.StreamHandler()
    stderrHandler.setFormatter(logging.Formatter(config['logging']['format'], style = '{'))
    root.addHandler(stderrHandler)
-
-
async def main():
    '''Load the config file named on the command line and run the IRC client and web server until SIGINT.

    SIGUSR1 reloads the config file; if that fails, the previous configuration stays in effect.
    '''
    if len(sys.argv) != 2:
        print('Usage: http2irc.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    configFile = sys.argv[1]
    config = Config(configFile)
    configure_logging(config)

    loop = asyncio.get_running_loop()

    messageQueue = MessageQueue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(messageQueue, config)

    sigintEvent = asyncio.Event()
    def sigint_callback():
        logging.info('Got SIGINT')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    def sigusr1_callback():
        logging.info('Got SIGUSR1, reloading config')
        nonlocal config
        try:
            newConfig = config.reread()
        except Exception as e:
            # Not just InvalidConfig: the file might also be missing, unreadable, or not parseable at all.
            # Whatever went wrong, keep running with the previous configuration.
            logging.error(f'Config reload failed: {e!s}')
            return
        config = newConfig
        configure_logging(config)
        irc.update_config(config)
        webserver.update_config(config)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))
-
-
# Only start the event loop when executed as a script, not when imported as a module
if __name__ == '__main__':
    asyncio.run(main())
|