|
import asyncio
import base64
import collections
import collections.abc
import concurrent.futures
import hmac
import importlib.util
import inspect
import itertools
import logging
import os.path
import signal
import ssl
import string
import sys

import aiohttp
import aiohttp.web
import toml
-
-
logger = logging.getLogger('http2irc')


def _make_insecure_ssl_context():
    '''Return a TLS client context that performs no certificate or hostname verification.

    Backs the 'insecure' IRC SSL setting. The protocol is passed explicitly because calling
    ssl.SSLContext() without one is deprecated since Python 3.10.
    '''
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be switched off before verify_mode may be lowered to CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx


# irc.ssl config value -> `ssl` argument for loop.create_connection:
# True = TLS with the default (verifying) context, False = plaintext, SSLContext = TLS without verification.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}
-
-
class InvalidConfig(Exception):
    '''Raised when the configuration file contains unknown keys or invalid values.'''
-
-
def is_valid_pem(path, withCert):
    '''Very basic check whether a file looks like a PEM private key, optionally preceded by a certificate.

    With withCert, the file must start with a "BEGIN/END CERTIFICATE" block that is immediately followed
    by a "BEGIN/END PRIVATE KEY" block; without it, the file must start with the private key block.
    Markers must end in LF, the block bodies must be valid base64 (line breaks ignored), and nothing may
    follow the key block.

    Returns True or False. Unreadable files and malformed contents yield False rather than an exception.
    '''
    certBegin = b'-----BEGIN CERTIFICATE-----\n'
    certEnd = b'-----END CERTIFICATE-----\n'
    keyBegin = b'-----BEGIN PRIVATE KEY-----\n'
    keyEnd = b'-----END PRIVATE KEY-----\n'

    # Explicit checks instead of assert: asserts are stripped under `python -O`, which would make
    # this function accept malformed files.
    try:
        with open(path, 'rb') as fp:
            contents = fp.read()

        if withCert:
            if not contents.startswith(certBegin):
                return False
            endCertPos = contents.index(certEnd) # ValueError if missing
            base64.b64decode(contents[len(certBegin):endCertPos].replace(b'\n', b''), validate = True)
            keyStart = endCertPos + len(certEnd) # The key block must follow the certificate directly
        else:
            keyStart = 0

        if not contents.startswith(keyBegin, keyStart):
            return False
        endKeyPos = contents.index(keyEnd, keyStart) # ValueError if missing
        base64.b64decode(contents[keyStart + len(keyBegin):endKeyPos].replace(b'\n', b''), validate = True)
        return contents[endKeyPos + len(keyEnd):] == b''
    except (OSError, ValueError): # ValueError also covers binascii.Error from b64decode
        return False
-
-
class Config(dict):
    '''Validated http2irc configuration loaded from a TOML file.

    After construction the instance maps 'logging', 'irc', 'web', and 'maps' to dicts in which every
    key the file omits has been filled in with its default. In each map, 'module' holds the loaded
    module object (or None) instead of the path given in the file.

    Raises InvalidConfig when the contents fail validation; errors from opening the file or from
    toml.load propagate unchanged.
    '''

    def __init__(self, filename):
        super().__init__()
        self._filename = filename

        with open(self._filename, 'r') as fp:
            obj = toml.load(fp)

        # Default values. Defined before validation so that checks which combine several keys (the
        # NICK/USER length checks) can fall back to them when the file omits one of those keys.
        finalObj = {
            'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
            'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
            'web': {'host': '127.0.0.1', 'port': 8080},
            'maps': {},
        }

        # Sanity checks
        if any(x not in ('logging', 'irc', 'web', 'maps') for x in obj.keys()):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')

        if 'logging' in obj:
            if any(x not in ('level', 'format') for x in obj['logging']):
                raise InvalidConfig('Unknown key found in log section')
            if 'level' in obj['logging'] and obj['logging']['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
                raise InvalidConfig('Invalid log level')
            if 'format' in obj['logging']:
                if not isinstance(obj['logging']['format'], str):
                    raise InvalidConfig('Invalid log format')
                #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
                # Count the replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                try:
                    fieldCount = sum(1 for x in string.Formatter().parse(obj['logging']['format']) if x[1] is not None)
                except ValueError as e:
                    raise InvalidConfig('Invalid log format: parsing failed') from e
                if fieldCount == 0: # Explicit check rather than assert, which `python -O` would strip
                    raise InvalidConfig('Invalid log format: parsing failed')

        if 'irc' in obj:
            if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
                raise InvalidConfig('Unknown key found in irc section')
            if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname
                raise InvalidConfig('Invalid IRC host')
            if 'port' in obj['irc'] and (not isinstance(obj['irc']['port'], int) or not 1 <= obj['irc']['port'] <= 65535):
                raise InvalidConfig('Invalid IRC port')
            if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'):
                raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
            if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
                raise InvalidConfig('Invalid IRC nick')
            # nick/real may be omitted; the length checks then use the defaults that will be merged in below.
            nick = obj['irc'].get('nick', finalObj['irc']['nick'])
            if len(IRCClientProtocol.nick_command(nick)) > 510:
                raise InvalidConfig('Invalid IRC nick: NICK command too long')
            if 'real' in obj['irc'] and not isinstance(obj['irc']['real'], str):
                raise InvalidConfig('Invalid IRC realname')
            real = obj['irc'].get('real', finalObj['irc']['real'])
            if len(IRCClientProtocol.user_command(nick, real)) > 510:
                raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
            if ('certfile' in obj['irc']) != ('certkeyfile' in obj['irc']):
                raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
            if 'certfile' in obj['irc']:
                if not isinstance(obj['irc']['certfile'], str):
                    raise InvalidConfig('Invalid certificate file: not a string')
                if not os.path.isfile(obj['irc']['certfile']):
                    raise InvalidConfig('Invalid certificate file: not a regular file')
                if not is_valid_pem(obj['irc']['certfile'], True):
                    raise InvalidConfig('Invalid certificate file: not a valid PEM cert')
            if 'certkeyfile' in obj['irc']:
                if not isinstance(obj['irc']['certkeyfile'], str):
                    raise InvalidConfig('Invalid certificate key file: not a string')
                if not os.path.isfile(obj['irc']['certkeyfile']):
                    raise InvalidConfig('Invalid certificate key file: not a regular file')
                if not is_valid_pem(obj['irc']['certkeyfile'], False):
                    raise InvalidConfig('Invalid certificate key file: not a valid PEM key')

        if 'web' in obj:
            if any(x not in ('host', 'port') for x in obj['web']):
                raise InvalidConfig('Unknown key found in web section')
            if 'host' in obj['web'] and not isinstance(obj['web']['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
                raise InvalidConfig('Invalid web hostname')
            if 'port' in obj['web'] and (not isinstance(obj['web']['port'], int) or not 1 <= obj['web']['port'] <= 65535):
                raise InvalidConfig('Invalid web port')

        # The maps section is optional; every loop below works on this (possibly empty) dict.
        maps = obj.get('maps', {})
        seenWebPaths = {} # webpath -> key of the map that claimed it
        for key, map_ in maps.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid map key {key!r}')
            if not isinstance(map_, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid map for {key!r}')
            if any(x not in ('webpath', 'ircchannel', 'auth', 'module', 'moduleargs') for x in map_):
                raise InvalidConfig(f'Unknown key(s) found in map {key!r}')

            if 'webpath' not in map_:
                map_['webpath'] = f'/{key}'
            if not isinstance(map_['webpath'], str):
                raise InvalidConfig(f'Invalid map {key!r} web path: not a string')
            if not map_['webpath'].startswith('/'):
                raise InvalidConfig(f'Invalid map {key!r} web path: does not start at the root')
            if map_['webpath'] in seenWebPaths:
                raise InvalidConfig(f'Invalid map {key!r} web path: collides with map {seenWebPaths[map_["webpath"]]!r}')
            seenWebPaths[map_['webpath']] = key

            if 'ircchannel' not in map_:
                map_['ircchannel'] = f'#{key}'
            if not isinstance(map_['ircchannel'], str):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: not a string')
            if not map_['ircchannel'].startswith('#') and not map_['ircchannel'].startswith('&'):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: does not start with # or &')
            if any(x in map_['ircchannel'][1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: contains forbidden characters')
            # Measured in bytes: the IRC line limits (JOIN batching, PRIVMSG splitting) are byte-based.
            if len(map_['ircchannel'].encode('utf-8')) > 200:
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: too long')

            if 'auth' in map_:
                if map_['auth'] is not False and not isinstance(map_['auth'], str):
                    raise InvalidConfig(f'Invalid map {key!r} auth: must be false or a string')
                if isinstance(map_['auth'], str) and ':' not in map_['auth']:
                    raise InvalidConfig(f'Invalid map {key!r} auth: must contain a colon')

            if 'module' in map_:
                if not isinstance(map_['module'], str):
                    raise InvalidConfig(f'Invalid module for {key!r}: not a string')
                if not os.path.isfile(map_['module']):
                    raise InvalidConfig(f'Module {map_["module"]!r} in map {key!r} is not a file')
            if 'moduleargs' in map_:
                if not isinstance(map_['moduleargs'], list):
                    raise InvalidConfig(f'Invalid module args for {key!r}: not an array')
                if 'module' not in map_:
                    raise InvalidConfig(f'Module args cannot be specified without a module for {key!r}')

        # Fill in default values for the maps (webpath and ircchannel were already set above)
        for map_ in maps.values():
            map_.setdefault('auth', False)
            map_.setdefault('module', None)
            map_.setdefault('moduleargs', [])

        # Load modules; a module shared by several maps must be called with the same number of extra args.
        modulePaths = {} # path: str -> (extraargs: int, key: str)
        for key, map_ in maps.items():
            if map_['module'] is not None:
                if map_['module'] not in modulePaths:
                    modulePaths[map_['module']] = (len(map_['moduleargs']), key)
                elif modulePaths[map_['module']][0] != len(map_['moduleargs']):
                    raise InvalidConfig(f'Module {map_["module"]!r} process function extra argument inconsistency between {key!r} and {modulePaths[map_["module"]][1]!r}')
        modules = {} # path: str -> module: module
        for i, (path, (extraargs, _)) in enumerate(modulePaths.items()):
            try:
                # Build a name that is virtually guaranteed to be unique across a process.
                # Although importlib does not seem to perform any caching as of CPython 3.8, this is not guaranteed by spec.
                spec = importlib.util.spec_from_file_location(f'http2irc-module-{id(self)}-{i}', path)
                if spec is None or spec.loader is None: # e.g. a file extension importlib has no loader for
                    raise ImportError('no loader available for this file')
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except Exception as e: # This is ugly, but exec_module can raise virtually any exception
                raise InvalidConfig(f'Loading module {path!r} failed: {e!s}') from e
            if not hasattr(module, 'process'):
                raise InvalidConfig(f'Module {path!r} does not have a process function')
            if not inspect.iscoroutinefunction(module.process):
                raise InvalidConfig(f'Module {path!r} process attribute is not a coroutine function')
            nargs = len(inspect.signature(module.process).parameters)
            if nargs != 1 + extraargs:
                raise InvalidConfig(f'Module {path!r} process function takes {nargs} parameter{"s" if nargs != 1 else ""}, not {1 + extraargs}')
            modules[path] = module

        # Replace module paths with the loaded module objects
        for map_ in maps.values():
            if map_['module'] is not None:
                map_['module'] = modules[map_['module']]

        # Merge in what was read from the config file and set keys on self
        for key in ('logging', 'irc', 'web', 'maps'):
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    def __repr__(self):
        return f'<Config(logging={self["logging"]!r}, irc={self["irc"]!r}, web={self["web"]!r}, maps={self["maps"]!r})>'

    def reread(self):
        '''Load the same file again and return a new Config; self is left unchanged.'''
        return Config(self._filename)
-
-
class MessageQueue:
    # Holds (channel, message) pairs received over HTTP until the IRC client delivers them.
    # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    # Differences to asyncio.Queue include:
    # - No maxsize
    # - No put coroutine (not necessary since the queue can never be full)
    # - Only one concurrent getter
    # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

    logger = logging.getLogger('http2irc.MessageQueue')

    def __init__(self):
        self._getter = None # None | asyncio.Future; set only while get() waits on an empty queue
        self._queue = collections.deque()

    async def get(self):
        '''Remove and return the first item, waiting until one is available.

        Raises RuntimeError if another get() is already waiting.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        if len(self._queue) == 0:
            self._getter = asyncio.get_running_loop().create_future()
            self.logger.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                self.logger.debug('Cancelled getter')
                raise
            finally:
                self._getter = None # Free the getter slot whether we were woken or cancelled
            self.logger.debug('Awaited getter')
        return self.get_nowait()

    def get_nowait(self):
        '''Remove and return the first item; raise asyncio.QueueEmpty if there is none.'''
        if len(self._queue) == 0:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def _wake_getter(self):
        '''Resume a waiting get(), if there is one and the queue has something for it.'''
        # done() also covers a future that an earlier put already resolved before get() resumed; calling
        # set_result on it again would raise InvalidStateError. Waking on an empty queue (e.g. after
        # putleft_nowait() with no items) would make get() raise QueueEmpty.
        if self._getter is not None and not self._getter.done() and self._queue:
            self._getter.set_result(None)

    def put_nowait(self, item):
        '''Append item to the back of the queue.'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, *item):
        '''Put the given items at the front of the queue, keeping their relative order.'''
        self._queue.extendleft(reversed(item))
        self._wake_getter()

    def qsize(self):
        '''Return the number of queued items.'''
        return len(self._queue)
-
-
class IRCClientProtocol(asyncio.Protocol):
    '''asyncio protocol for a single IRC connection.

    On connect it registers, requesting SASL EXTERNAL when a client certificate is configured. After the
    001 welcome it joins self.channels and starts two background tasks. send_messages relays
    (channel, message) pairs from the shared message queue as PRIVMSGs. confirm_messages periodically
    PINGs the server and puts unconfirmed messages back onto the queue when no PONG arrives.
    '''

    logger = logging.getLogger('http2irc.IRCClientProtocol')

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent # Set by connection_lost
        self.loop = loop
        self.config = config
        self.buffer = b'' # Trailing incomplete line left over from the previous data_received call
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.unconfirmedMessages = [] # (channel, message) pairs sent since the last delivery confirmation
        self.pongReceivedEvent = asyncio.Event()
        self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile']) # SASL EXTERNAL iff a client cert is configured
        self.authenticated = False # Set on numeric 903 (SASL success)
        self.usermask = None # nick!user@host as others see us, once known; used to size PRIVMSG splits
        self._tasks = set() # Strong references to background tasks so they cannot be garbage-collected mid-run

    @staticmethod
    def nick_command(nick: str):
        '''Return the NICK command line (without CRLF) for nick.'''
        return b'NICK ' + nick.encode('utf-8')

    @staticmethod
    def user_command(nick: str, real: str):
        '''Return the USER command line (without CRLF), using nick for the user, host, and server fields.'''
        nickb = nick.encode('utf-8')
        return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')

    @staticmethod
    async def _wait_first(*aws):
        '''Wait until the first of the awaitables finishes, cancel the others, and return (done, pending).

        Everything is wrapped in tasks first because asyncio.wait() no longer accepts bare coroutines
        (Python 3.11+); cancelling the losers avoids leaking tasks that would otherwise linger forever.
        '''
        tasks = [asyncio.ensure_future(aw) for aw in aws]
        done, pending = await asyncio.wait(tasks, return_when = concurrent.futures.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        return done, pending

    def _maybe_set_usermask(self, usermask):
        '''Store usermask if it has the nick!user@host shape and no spaces, wildcards, or channel prefixes.'''
        if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
            self.usermask = usermask
            self.logger.debug(f'Usermask is now {usermask!r}')

    def connection_made(self, transport):
        '''Start registration: request the sasl capability if needed, then send NICK and USER.'''
        self.logger.info('IRC connected')
        self.transport = transport
        self.connected = True
        if self.sasl:
            self.send(b'CAP REQ :sasl')
        self.send(self.nick_command(self.config['irc']['nick']))
        self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))

    def _send_join_part(self, command, channels):
        '''Split a JOIN or PART into multiple messages as necessary'''
        # command: b'JOIN' or b'PART'; channels: iterable of str

        if not channels:
            # "JOIN " without a channel would get 461 (not enough parameters), which terminates the connection.
            return

        channels = [x.encode('utf-8') for x in channels]
        if len(command) + sum(1 + len(x) for x in channels) <= 510: # Total length = command + (separator + channel name for each channel, where the separator is a space for the first and then a comma)
            # Everything fits into one command.
            self.send(command + b' ' + b','.join(channels))
            return

        # List too long, need to split.
        limit = 510 - len(command)
        lengths = [1 + len(x) for x in channels] # separator + channel name
        chanLengthAcceptable = [l <= limit for l in lengths]
        if not all(chanLengthAcceptable):
            # There are channel names that are too long to even fit into one message on their own; filter them out and warn about them.
            # This should never happen since the config reader would already filter it out.
            tooLongChannels = [x for x, a in zip(channels, chanLengthAcceptable) if not a]
            channels = [x for x, a in zip(channels, chanLengthAcceptable) if a]
            lengths = [l for l, a in zip(lengths, chanLengthAcceptable) if a]
            for channel in tooLongChannels:
                self.logger.warning(f'Cannot {command.decode("ascii")} {channel.decode("utf-8")}: name too long')
        runningLengths = list(itertools.accumulate(lengths)) # entry N = length of all entries up to and including channel N, including separators
        offset = 0
        while channels:
            # First channel that would push this batch over the limit; -1 if the rest all fit.
            i = next((x[0] for x in enumerate(runningLengths) if x[1] - offset > limit), -1)
            if i == -1: # Last batch
                i = len(channels)
            self.send(command + b' ' + b','.join(channels[:i]))
            offset = runningLengths[i-1]
            channels = channels[i:]
            runningLengths = runningLengths[i:]

    def update_channels(self, channels: set):
        '''Replace the channel set, sending PART/JOIN for the differences while connected.'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                self._send_join_part(b'PART', channelsToPart)
            if channelsToJoin:
                self._send_join_part(b'JOIN', channelsToJoin)

    def send(self, data):
        '''Write one IRC line (CRLF is appended); raise RuntimeError if data exceeds 510 bytes.'''
        self.logger.debug(f'Send: {data!r}')
        if len(data) > 510:
            raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
        self.transport.write(data + b'\r\n')

    async def _get_message(self):
        '''Return the next (channel, message) from the queue, or (None, None) once the connection has closed.

        On close, a pending get is cancelled and awaited so that the queue's single getter slot is free
        again. A message that was already retrieved is put back at the front of the queue.
        '''
        self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        messageFuture = asyncio.create_task(self.messageQueue.get())
        closedFuture = asyncio.create_task(self.connectionClosedEvent.wait())
        done, pending = await asyncio.wait((messageFuture, closedFuture), return_when = concurrent.futures.FIRST_COMPLETED)
        if closedFuture in pending:
            closedFuture.cancel() # Don't leave an orphaned Event.wait() task behind
        if self.connectionClosedEvent.is_set():
            if messageFuture in pending:
                self.logger.debug('Cancelling messageFuture')
                messageFuture.cancel()
                try:
                    await messageFuture
                except asyncio.CancelledError:
                    self.logger.debug('Cancelled messageFuture')
            else:
                # messageFuture is already done but we're stopping, so put the result back onto the queue
                self.messageQueue.putleft_nowait(messageFuture.result())
            return None, None
        assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
        return messageFuture.result()

    async def send_messages(self):
        '''Relay queued messages to IRC as PRIVMSGs until the connection closes.

        A message that would exceed the 510-byte line limit is split and the pieces are put back at the
        front of the queue. The limit accounts for the origin prefix the server adds when relaying. Splits
        happen on spaces where possible, otherwise on codepoint boundaries. Sent messages are recorded in
        unconfirmedMessages for confirm_messages. Each iteration ends with a one-second sleep as a rate limit.
        '''
        while self.connected:
            self.logger.debug(f'Trying to get a message')
            channel, message = await self._get_message()
            self.logger.debug(f'Got message: {message!r}')
            if message is None:
                break
            channelB = channel.encode('utf-8')
            messageB = message.encode('utf-8')
            # ':' + usermask + ' '; a 100-byte usermask is assumed while the real one is unknown
            usermaskPrefixLength = 1 + (len(self.usermask) if self.usermask else 100) + 1
            if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
                self.logger.debug(f'Splitting up into smaller messages')
                # Message too long, need to split. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
                prefix = b'PRIVMSG ' + channelB + b' :'
                prefixLength = usermaskPrefixLength + len(prefix) # Need to account for the origin prefix included by the ircd when sending to others
                maxMessageLength = 510 - prefixLength # maximum length of the message part within each line
                messages = []
                while message:
                    if len(messageB) <= maxMessageLength:
                        messages.append(message)
                        break

                    spacePos = messageB.rfind(b' ', 0, maxMessageLength + 1)
                    if spacePos != -1:
                        messages.append(messageB[:spacePos].decode('utf-8'))
                        messageB = messageB[spacePos + 1:]
                        message = messageB.decode('utf-8')
                        continue

                    # No space found, need to search for a suitable codepoint location.
                    pMessage = message[:maxMessageLength] # at most maxMessageLength codepoints, which expand to at least as many bytes
                    pLengths = [len(x.encode('utf-8')) for x in pMessage] # byte size of each codepoint
                    pRunningLengths = list(itertools.accumulate(pLengths)) # byte size up to each codepoint
                    if pRunningLengths[-1] <= maxMessageLength: # Special case: entire pMessage is short enough
                        messages.append(pMessage)
                        message = message[maxMessageLength:]
                        messageB = message.encode('utf-8')
                        continue
                    cutoffIndex = next(x[0] for x in enumerate(pRunningLengths) if x[1] > maxMessageLength)
                    messages.append(message[:cutoffIndex])
                    message = message[cutoffIndex:]
                    messageB = message.encode('utf-8')
                for msg in reversed(messages):
                    self.messageQueue.putleft_nowait((channel, msg))
            else:
                self.logger.info(f'Sending {message!r} to {channel!r}')
                self.unconfirmedMessages.append((channel, message))
                self.send(b'PRIVMSG ' + channelB + b' :' + messageB)
            await asyncio.sleep(1) # Rate limit

    async def confirm_messages(self):
        '''Once a minute, confirm delivery of sent messages with a PING/PONG round trip.

        Without a PONG within five seconds, the unconfirmed messages are put back at the front of the
        queue and the transport is closed. On disconnect they are requeued directly.
        '''
        while self.connected:
            await self._wait_first(asyncio.sleep(60), self.connectionClosedEvent.wait()) # Confirm once per minute
            if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly
                self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
                self.unconfirmedMessages = []
                break
            if not self.unconfirmedMessages:
                self.logger.debug('No messages to confirm')
                continue
            self.logger.debug('Trying to confirm message delivery')
            self.pongReceivedEvent.clear()
            self.send(b'PING :42')
            await self._wait_first(asyncio.sleep(5), self.pongReceivedEvent.wait())
            self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}')
            if not self.pongReceivedEvent.is_set():
                # No PONG received in five seconds, assume connection's dead
                self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue')
                self.messageQueue.putleft_nowait(*self.unconfirmedMessages)
                self.transport.close()
            self.unconfirmedMessages = []

    def data_received(self, data):
        '''Split the incoming byte stream into CRLF-terminated lines and dispatch each complete one.'''
        self.logger.debug(f'Data received: {data!r}')
        # Prepend the incomplete line left over from last time, then split on CRLF. The last element is
        # b'' if data ended on CRLF, otherwise an incomplete line; keep it for the next call. Joining first
        # (rather than only gluing the buffer onto the first piece) also handles chunks without any CRLF
        # and CRLFs that straddle two chunks.
        messages = (self.buffer + data).split(b'\r\n')
        self.buffer = messages.pop()
        for message in messages:
            self.message_received(message)

    def message_received(self, message):
        '''Handle one IRC line (without the trailing CRLF).'''
        self.logger.debug(f'Message received: {message!r}')
        rawMessage = message
        if message.startswith(b':') and b' ' in message:
            # Prefixed message, extract command + parameters (the prefix cannot contain a space)
            message = message.split(b' ', 1)[1]

        # PING/PONG
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])
        elif message.startswith(b'PONG '):
            self.pongReceivedEvent.set()

        # SASL
        elif message.startswith(b'CAP ') and self.sasl:
            if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
                self.send(b'AUTHENTICATE EXTERNAL')
            else:
                self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
                self.transport.close()
        elif message == b'AUTHENTICATE +':
            self.send(b'AUTHENTICATE +')
        elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
            words = message.split(b' ')
            if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
                if b'!~' not in words[2]:
                    # At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
                    words[2] = words[2].replace(b'!', b'!~', 1)
                self._maybe_set_usermask(words[2])
        elif message.startswith(b'903 '): # SASL auth successful
            self.authenticated = True
            self.send(b'CAP END')
        elif any(message.startswith(x) for x in (b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
            self.logger.error('SASL error, terminating connection')
            self.transport.close()

        # NICK errors
        elif any(message.startswith(x) for x in (b'431 ', b'432 ', b'433 ', b'436 ')):
            self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
            self.transport.close()

        # USER errors
        elif any(message.startswith(x) for x in (b'461 ', b'462 ')):
            self.logger.error(f'Failed to register: {message!r}, terminating connection')
            self.transport.close()

        # JOIN errors
        elif any(message.startswith(x) for x in (b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
            self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
            self.transport.close()

        # PART errors
        elif message.startswith(b'442 '):
            self.logger.error(f'Failed to part channel: {message!r}')

        # JOIN/PART errors
        elif message.startswith(b'403 '):
            self.logger.error(f'Failed to join or part channel: {message!r}')

        # PRIVMSG errors
        elif any(message.startswith(x) for x in (b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')):
            self.logger.error(f'Failed to send message: {message!r}')

        # Connection registration reply
        elif message.startswith(b'001 '):
            self.logger.info('IRC connection registered')
            if self.sasl and not self.authenticated:
                self.logger.error('IRC connection registered but not authenticated, terminating connection')
                self.transport.close()
                return
            self._send_join_part(b'JOIN', self.channels)
            for coro in (self.send_messages(), self.confirm_messages()):
                # The event loop only keeps weak references to tasks; hold them until they finish.
                task = asyncio.create_task(coro)
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)

        # JOIN success
        elif message.startswith(b'JOIN ') and not self.usermask:
            # If this is my own join message, it should contain the usermask in the prefix
            if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
                usermask = rawMessage.split(b' ', 1)[0][1:]
                self._maybe_set_usermask(usermask)

        # Services host change
        elif message.startswith(b'396 '):
            words = message.split(b' ')
            if len(words) >= 3:
                # Sanity check inspired by irssi src/irc/core/irc-servers.c
                if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
                    if b'@' in words[2]: # user@host
                        self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2])
                    else: # host (get user from previous mask or settings)
                        if self.usermask:
                            user = self.usermask.split(b'@')[0].split(b'!')[1]
                        else:
                            user = b'~' + self.config['irc']['nick'].encode('utf-8')
                        self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])

    def connection_lost(self, exc):
        '''Mark the protocol as disconnected and signal connectionClosedEvent.'''
        self.logger.info('IRC connection lost')
        self.connected = False
        self.connectionClosedEvent.set()
-
-
class IRCClient:
    '''Owns the IRC connection: connects with IRCClientProtocol, reconnects until SIGINT, and applies config reloads.'''

    logger = logging.getLogger('http2irc.IRCClient')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = self._channels_from_config(config)

        self._transport = None # Transport of the most recent connection (may already be closed)
        self._protocol = None # IRCClientProtocol of the most recent connection

    @staticmethod
    def _channels_from_config(config):
        '''Return the set of IRC channels referenced by the config's maps.'''
        return {map_['ircchannel'] for map_ in config['maps'].values()}

    @staticmethod
    async def _wait_first(*aws):
        '''Wait until the first of the awaitables finishes, cancel the others, and return (done, pending).

        asyncio.wait() no longer accepts bare coroutines (Python 3.11+), so everything is wrapped in tasks.
        '''
        tasks = [asyncio.ensure_future(aw) for aw in aws]
        done, pending = await asyncio.wait(tasks, return_when = concurrent.futures.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        return done, pending

    def update_config(self, config):
        '''Apply a reloaded config.

        The channel set is always recomputed, so the next (re)connection joins the new channels. If the irc
        section changed, the current connection is closed so that run() reconnects with the new settings.
        Otherwise the current protocol is told to JOIN/PART the difference.
        '''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        self.channels = self._channels_from_config(config)
        if self._transport: # a connection has been made (it may have been lost since)
            if needReconnect:
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Return the `ssl` argument for create_connection, with the client certificate loaded if configured.

        Note that for the 'insecure' setting this loads the certificate into the shared
        SSL_CONTEXTS['insecure'] object. With 'no', no certificate is used.
        '''
        ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
        if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
            if ctx is True:
                ctx = ssl.create_default_context()
            if isinstance(ctx, ssl.SSLContext):
                ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
        return ctx

    async def run(self, loop, sigintEvent):
        '''Keep an IRC connection up until sigintEvent is set.

        Connection failures (any OSError, which includes refused connections, DNS and TLS errors, as well
        as timeouts) are logged and retried after five seconds.
        '''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
            except (OSError, asyncio.TimeoutError) as e:
                self.logger.error(str(e))
                await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break
-
-
class WebServer:
    '''HTTP front end: accepts POSTs on the configured web paths and queues the resulting IRC messages.'''

    logger = logging.getLogger('http2irc.WebServer')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config

        self._paths = {} # '/path' => ('#channel', auth, module, moduleargs) where auth is either False (no authentication) or the HTTP header value for basic auth

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        # Created before update_config, which sets it when the web section changes.
        self._configChanged = asyncio.Event()
        self.update_config(config)

    @staticmethod
    def _basic_auth_header(auth):
        '''Return the expected Authorization header value for 'user:password', or False if auth is disabled.'''
        if not auth:
            return False
        return f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'

    @staticmethod
    def _auth_matches(authHeader, expected):
        '''Compare the client's Authorization header (None if absent) with the expected value in constant time.'''
        if authHeader is None:
            return False
        # compare_digest rejects non-ASCII str arguments, so compare UTF-8 bytes; surrogatepass keeps
        # lone surrogates from making the encoding itself raise.
        return hmac.compare_digest(authHeader.encode('utf-8', 'surrogatepass'), expected.encode('utf-8'))

    @staticmethod
    async def _wait_first(*aws):
        '''Wait until the first of the awaitables finishes, cancel the others, and return (done, pending).

        asyncio.wait() no longer accepts bare coroutines (Python 3.11+), so everything is wrapped in tasks.
        '''
        tasks = [asyncio.ensure_future(aw) for aw in aws]
        done, pending = await asyncio.wait(tasks, return_when = concurrent.futures.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        return done, pending

    def update_config(self, config):
        '''Rebuild the path table from config; if the web section changed, signal run() to rebind.'''
        self._paths = {
            map_['webpath']: (map_['ircchannel'], self._basic_auth_header(map_['auth']), map_['module'], map_['moduleargs'])
            for map_ in config['maps'].values()
        }
        needRebind = self.config['web'] != config['web']
        self.config = config
        if needRebind:
            self._configChanged.set()

    async def run(self, stopEvent):
        '''Serve HTTP until stopEvent is set, rebinding to the new host/port after a web config change.'''
        while True:
            runner = aiohttp.web.AppRunner(self._app)
            await runner.setup()
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                await self._wait_first(stopEvent.wait(), self._configChanged.wait())
            finally:
                # Also release the runner if binding the site failed
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

    async def post(self, request):
        '''Handle a POST: resolve the path's map, check auth, build the message, and queue it.

        Responds 404 for unknown paths, 403 on failed authentication, 400 for payloads that cannot be
        turned into a single-line message, 500 if a module returns a non-string, and 200 once queued.
        '''
        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.path!r}')
        try:
            channel, auth, module, moduleargs = self._paths[request.path]
        except KeyError:
            self.logger.info(f'Bad request {id(request)}: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            if not self._auth_matches(authHeader, auth):
                # Log neither the expected value (the configured credential) nor the submitted one.
                self.logger.info(f'Bad request {id(request)}: authentication failed ({"no" if authHeader is None else "wrong"} Authorization header)')
                raise aiohttp.web.HTTPForbidden()
        if module is not None:
            self.logger.debug(f'Processing request {id(request)} using {module!r}')
            try:
                message = await module.process(request, *moduleargs)
            except aiohttp.web.HTTPException:
                raise
            except Exception as e:
                self.logger.error(f'Bad request {id(request)}: exception in module process function: {e!s}')
                raise aiohttp.web.HTTPBadRequest()
            if not isinstance(message, str):
                # A non-string would otherwise be queued and crash the IRC sender when it calls .encode().
                self.logger.error(f'Request {id(request)}: module process function returned a non-string: {message!r}')
                raise aiohttp.web.HTTPInternalServerError()
            if '\r' in message or '\n' in message:
                self.logger.error(f'Bad request {id(request)}: module process function returned message with linebreaks: {message!r}')
                raise aiohttp.web.HTTPBadRequest()
        else:
            self.logger.debug(f'Processing request {id(request)} using default processor')
            message = await self._default_process(request)
        self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message))
        raise aiohttp.web.HTTPOk()

    async def _default_process(self, request):
        '''Use the request body as the message, minus one optional trailing [CR]LF; reject other line breaks.'''
        try:
            message = await request.text()
        except Exception as e:
            self.logger.info(f'Bad request {id(request)}: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        self.logger.debug(f'Request {id(request)} payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            self.logger.info(f'Bad request {id(request)}: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        return message
-
-
def configure_logging(config):
    '''(Re)configure the root logger from the config's logging section.

    Sets the root level and replaces every existing root handler with a single stderr StreamHandler that
    uses the configured '{'-style format. Replaced handlers are removed and closed, so calling this again
    on a config reload does not accumulate handlers.
    '''
    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    root = logging.getLogger()
    root.setLevel(getattr(logging, config['logging']['level']))
    # Same as what basicConfig(force = True) does, but through the public removeHandler API
    # instead of assigning to the undocumented root.handlers attribute.
    for handler in root.handlers[:]:
        root.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter(config['logging']['format'], style = '{')
    stderrHandler = logging.StreamHandler()
    stderrHandler.setFormatter(formatter)
    root.addHandler(stderrHandler)
-
-
async def main():
    '''Program entry: load the config, run the IRC client and web server, and handle SIGINT/SIGUSR1.

    Expects exactly one command-line argument, the config file path; otherwise prints usage and exits 1.
    '''
    if len(sys.argv) != 2:
        print('Usage: http2irc.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    config = Config(sys.argv[1])
    configure_logging(config)

    loop = asyncio.get_running_loop()
    messageQueue = MessageQueue()
    irc = IRCClient(messageQueue, config)
    webserver = WebServer(messageQueue, config)
    sigintEvent = asyncio.Event()

    def sigint_callback():
        # SIGINT: ask both components to stop
        logger.info('Got SIGINT, stopping')
        sigintEvent.set()

    def sigusr1_callback():
        # SIGUSR1: reload the config; on failure the current config stays in effect
        nonlocal config
        logger.info('Got SIGUSR1, reloading config')
        try:
            newConfig = config.reread()
        except InvalidConfig as e:
            logger.error(f'Config reload failed: {e!s} (old config remains active)')
            return
        config = newConfig
        configure_logging(config)
        irc.update_config(config)
        webserver.update_config(config)

    loop.add_signal_handler(signal.SIGINT, sigint_callback)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))


if __name__ == '__main__':
    asyncio.run(main())
|