|
- import aiohttp
- import aiohttp.web
- import asyncio
- import base64
- import collections
- import functools
- import importlib.util
- import inspect
- import ircstates
- import irctokens
- import itertools
- import json
- import logging
- import math
- import os.path
- import signal
- import socket
- import ssl
- import string
- import sys
- import time
- import toml
- import warnings
-
-
logger = logging.getLogger('http2irc')


def _make_insecure_ssl_context():
    '''Return a client TLS context that performs neither certificate nor hostname verification (irc.ssl = "insecure").'''
    # The no-argument ssl.SSLContext() constructor is deprecated since Python 3.10; build the equivalent explicitly.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # check_hostname has to be disabled before verify_mode may be set to CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


# Maps the irc.ssl config value to the `ssl` argument for the connection: True = verified TLS, False = plaintext.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}
-
-
class InvalidConfig(Exception):
    '''Raised when the configuration file is malformed or contains an invalid value.'''
    pass
-
-
def is_valid_pem(path, withCert):
    '''Very basic check whether a file looks like a valid PEM certificate and/or private key.

    With withCert, the file must consist of exactly one certificate block followed directly by a
    "-----BEGIN PRIVATE KEY-----" block; without it, of only that private key block. Lines must end in LF, the
    base64 payloads must decode, and nothing may follow the key.
    Returns True or False; unreadable or malformed files count as invalid rather than raising.
    '''
    beginCert = b'-----BEGIN CERTIFICATE-----\n'
    endCert = b'-----END CERTIFICATE-----\n'
    beginKey = b'-----BEGIN PRIVATE KEY-----\n'
    endKey = b'-----END PRIVATE KEY-----\n'

    def decodes_as_base64(body):
        try:
            base64.b64decode(body.replace(b'\n', b''), validate = True)
        except ValueError: # binascii.Error is a subclass of ValueError
            return False
        return True

    try:
        with open(path, 'rb') as fp:
            contents = fp.read()
    except OSError:
        return False

    # Explicit checks instead of assert so that the validation still happens under `python -O`.
    keyStart = 0
    if withCert:
        certEnd = contents.find(endCert)
        if not contents.startswith(beginCert) or certEnd == -1 or not decodes_as_base64(contents[len(beginCert):certEnd]):
            return False
        keyStart = certEnd + len(endCert)

    if not contents.startswith(beginKey, keyStart):
        return False
    keyEnd = contents.find(endKey, keyStart)
    if keyEnd == -1 or not decodes_as_base64(contents[keyStart + len(beginKey):keyEnd]):
        return False
    # Nothing may follow the key.
    return keyEnd + len(endKey) == len(contents)
-
-
async def wait_cancel_pending(aws, paws = None, **kwargs):
    '''asyncio.wait but with automatic cancellation of non-completed tasks.

    aws: tasks to wait for; any that are still pending when asyncio.wait returns are cancelled and awaited.
    paws: persistent awaitables (tasks) that are waited for as well but never cancelled here.
    kwargs: passed through to asyncio.wait (e.g. return_when, timeout).
    Returns asyncio.wait's (done, pending) sets; pending also contains the tasks that were cancelled here.
    If the calling task is itself cancelled while waiting, the non-persistent tasks are cancelled too before the
    CancelledError propagates.
    '''
    paws = set() if paws is None else set(paws)
    tasks = set(aws) | paws
    logger.debug(f'waiting for {tasks!r}')
    try:
        done, pending = await asyncio.wait(tasks, **kwargs)
    except asyncio.CancelledError:
        # asyncio.wait leaves its awaitables running when the waiter is cancelled; cancel the non-persistent ones so they don't leak.
        for task in tasks - paws:
            task.cancel()
        raise
    logger.debug(f'done waiting for {tasks!r}; cancelling pending non-persistent tasks: {pending!r}')
    for task in pending:
        if task in paws:
            continue
        logger.debug(f'cancelling {task!r}')
        task.cancel()
        logger.debug(f'awaiting cancellation of {task!r}')
        try:
            await task
        except asyncio.CancelledError:
            pass
        logger.debug(f'done cancelling {task!r}')
    logger.debug(f'done wait_cancel_pending {tasks!r}')
    return done, pending
-
-
class Config(dict):
    '''Validated http2irc configuration read from a TOML file.

    After construction, the dict has exactly the keys 'logging', 'irc', 'web', and 'maps', with defaults filled in for
    everything the file does not set. Relative certificate and module paths are resolved (modules also relative to this
    source file), irc.family is converted to a socket.AF_* constant, and map modules are imported.
    Raises InvalidConfig on unknown keys or invalid values.
    '''

    def __init__(self, filename):
        super().__init__()
        self._filename = filename

        with open(self._filename, 'r') as fp:
            obj = toml.load(fp)

        # Default values; also consulted during validation for settings the file omits (e.g. the USER length check).
        finalObj = self._defaults()

        # Sanity checks
        if any(x not in ('logging', 'irc', 'web', 'maps') for x in obj.keys()):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')
        if 'logging' in obj:
            self._check_logging(obj['logging'])
        if 'irc' in obj:
            self._check_irc(obj['irc'], finalObj['irc'])
        if 'web' in obj:
            self._check_web(obj['web'])

        # The maps section is optional; treat an absent one as empty rather than failing with a KeyError below.
        maps = obj.get('maps', {})
        self._check_maps(maps)
        self._fill_map_defaults(maps)
        self._load_modules(maps)

        # Merge in what was read from the config file and set keys on self
        for key in ('logging', 'irc', 'web', 'maps'):
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    @staticmethod
    def _defaults():
        '''Return a fresh dict with the default values of all sections.'''
        return {
            'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
            'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'family': 0, 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
            'web': {'host': '127.0.0.1', 'port': 8080, 'maxrequestsize': 1048576},
            'maps': {},
        }

    @staticmethod
    def _is_port(value):
        '''Whether value is a valid TCP port number (booleans, although int subclasses, are rejected).'''
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    @staticmethod
    def _check_logging(section):
        '''Validate the logging section.'''
        if any(x not in ('level', 'format') for x in section):
            raise InvalidConfig('Unknown key found in log section')
        if 'level' in section and section['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
            raise InvalidConfig('Invalid log level')
        if 'format' in section:
            if not isinstance(section['format'], str):
                raise InvalidConfig('Invalid log format')
            try:
                #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
                # This counts the number of replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                nFields = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
            except ValueError as e:
                raise InvalidConfig('Invalid log format: parsing failed') from e
            # Explicit check rather than assert so that it also runs under `python -O`.
            if nFields == 0:
                raise InvalidConfig('Invalid log format: parsing failed')

    def _check_irc(self, section, defaults):
        '''Validate the irc section in place; family becomes a socket.AF_* constant and cert paths become absolute.'''
        if any(x not in ('host', 'port', 'ssl', 'family', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in section and not self._is_port(section['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')
        if 'family' in section:
            if section['family'] not in ('inet', 'INET', 'inet6', 'INET6'):
                raise InvalidConfig('Invalid IRC family')
            section['family'] = getattr(socket, f'AF_{section["family"].upper()}')
        if 'nick' in section:
            nick = section['nick']
            if not isinstance(nick, str) or not nick:
                raise InvalidConfig('Invalid IRC nick')
            # RFC 2812: nickname = ( letter / special ) *( letter / digit / special / "-" )
            # The allowed characters in nicknames are a strict subset of the ones for usernames, so no need to also check for the latter.
            special = '[]\\`_^{|}'
            if nick[0] not in string.ascii_letters + special or nick.strip(string.ascii_letters + string.digits + special + '-') != '':
                raise InvalidConfig('Invalid IRC nick: contains illegal characters')
            if len(IRCClientProtocol.nick_command(nick)) > 510:
                raise InvalidConfig('Invalid IRC nick: NICK command too long')
        if 'real' in section and not isinstance(section['real'], str):
            raise InvalidConfig('Invalid IRC realname')
        # Either value may be absent from the file, in which case its default is what will be sent.
        if len(IRCClientProtocol.user_command(section.get('nick', defaults['nick']), section.get('real', defaults['real']))) > 510:
            raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
        if ('certfile' in section) != ('certkeyfile' in section):
            raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
        if 'certfile' in section:
            section['certfile'] = self._check_pem_file(section['certfile'], 'certificate file', True)
        if 'certkeyfile' in section:
            section['certkeyfile'] = self._check_pem_file(section['certkeyfile'], 'certificate key file', False)

    def _check_pem_file(self, value, label, withCert):
        '''Resolve a cert/key path relative to the config file, check it is a PEM file, and return the absolute path.'''
        if not isinstance(value, str):
            raise InvalidConfig(f'Invalid {label}: not a string')
        path = os.path.abspath(os.path.join(os.path.dirname(self._filename), value))
        if not os.path.isfile(path):
            raise InvalidConfig(f'Invalid {label}: not a regular file')
        if not is_valid_pem(path, withCert):
            raise InvalidConfig(f'Invalid {label}: not a valid PEM {"cert" if withCert else "key"}')
        return path

    def _check_web(self, section):
        '''Validate the web section.'''
        if any(x not in ('host', 'port', 'maxrequestsize') for x in section):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in section and not self._is_port(section['port']):
            raise InvalidConfig('Invalid web port')
        if 'maxrequestsize' in section:
            size = section['maxrequestsize']
            if not isinstance(size, int) or isinstance(size, bool) or size <= 0:
                raise InvalidConfig('Invalid web maxrequestsize')

    def _check_maps(self, maps):
        '''Validate the maps section in place: derive webpath/ircchannel from the key if absent, resolve auth aliases and module paths.'''
        seenWebPaths = {} # webpath -> map key, for collision detection
        for key, map_ in maps.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid map key {key!r}')
            if not isinstance(map_, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid map for {key!r}')
            if any(x not in ('webpath', 'ircchannel', 'auth', 'postauth', 'getauth', 'module', 'moduleargs', 'overlongmode') for x in map_):
                raise InvalidConfig(f'Unknown key(s) found in map {key!r}')

            if 'webpath' not in map_:
                map_['webpath'] = f'/{key}'
            if not isinstance(map_['webpath'], str):
                raise InvalidConfig(f'Invalid map {key!r} web path: not a string')
            if not map_['webpath'].startswith('/'):
                raise InvalidConfig(f'Invalid map {key!r} web path: does not start at the root')
            if map_['webpath'] == '/status':
                raise InvalidConfig(f'Invalid map {key!r} web path: cannot be "/status"')
            if map_['webpath'] in seenWebPaths:
                raise InvalidConfig(f'Invalid map {key!r} web path: collides with map {seenWebPaths[map_["webpath"]]!r}')
            seenWebPaths[map_['webpath']] = key

            if 'ircchannel' not in map_:
                map_['ircchannel'] = f'#{key}'
            if not isinstance(map_['ircchannel'], str):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: not a string')
            if not map_['ircchannel'].startswith('#') and not map_['ircchannel'].startswith('&'):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: does not start with # or &')
            if any(x in map_['ircchannel'][1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: contains forbidden characters')
            if len(map_['ircchannel']) > 200:
                raise InvalidConfig(f'Invalid map {key!r} IRC channel: too long')

            # For backward compatibility, 'auth' gets treated as 'postauth'
            if 'auth' in map_:
                if 'postauth' in map_:
                    raise InvalidConfig(f'Invalid map {key!r}: auth and postauth are aliases and cannot be used together')
                warnings.warn('auth is deprecated, use postauth instead', DeprecationWarning)
                map_['postauth'] = map_.pop('auth')
            for k in ('postauth', 'getauth'):
                if k not in map_:
                    continue
                if map_[k] is not False and not isinstance(map_[k], str):
                    raise InvalidConfig(f'Invalid map {key!r} {k}: must be false or a string')
                if isinstance(map_[k], str) and ':' not in map_[k]:
                    raise InvalidConfig(f'Invalid map {key!r} {k}: must contain a colon')

            if 'module' in map_:
                if not isinstance(map_['module'], str):
                    raise InvalidConfig(f'Invalid map {key!r} module: not a string')
                # If the path is relative, try to evaluate it relative to either the config file or this file; some modules are in the repo, but this also allows overriding them.
                for basePath in (os.path.dirname(self._filename), os.path.dirname(__file__)):
                    candidate = os.path.join(basePath, map_['module'])
                    if os.path.isfile(candidate):
                        map_['module'] = os.path.abspath(candidate)
                        break
                else:
                    raise InvalidConfig(f'Module {map_["module"]!r} in map {key!r} is not a file')
            if 'moduleargs' in map_:
                if not isinstance(map_['moduleargs'], list):
                    raise InvalidConfig(f'Invalid module args for {key!r}: not an array')
                if 'module' not in map_:
                    raise InvalidConfig(f'Module args cannot be specified without a module for {key!r}')
            if 'overlongmode' in map_:
                if not isinstance(map_['overlongmode'], str):
                    raise InvalidConfig(f'Invalid map {key!r} overlongmode: not a string')
                if map_['overlongmode'] not in ('split', 'truncate'):
                    raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value')

    @staticmethod
    def _fill_map_defaults(maps):
        '''Fill in default values for the maps (webpath and ircchannel were already set by _check_maps).'''
        for map_ in maps.values():
            map_.setdefault('postauth', False)
            map_.setdefault('getauth', False)
            map_.setdefault('module', None)
            map_.setdefault('moduleargs', [])
            map_.setdefault('overlongmode', 'split')

    def _load_modules(self, maps):
        '''Import each distinct map module once, check its process coroutine, and replace the path in each map with the module.'''
        modulePaths = {} # path: str -> (extraargs: int, key: str)
        for key, map_ in maps.items():
            if map_['module'] is None:
                continue
            if map_['module'] not in modulePaths:
                modulePaths[map_['module']] = (len(map_['moduleargs']), key)
            elif modulePaths[map_['module']][0] != len(map_['moduleargs']):
                raise InvalidConfig(f'Module {map_["module"]!r} process function extra argument inconsistency between {key!r} and {modulePaths[map_["module"]][1]!r}')
        modules = {} # path: str -> module: module
        for i, (path, (extraargs, _)) in enumerate(modulePaths.items()):
            try:
                # Build a name that is virtually guaranteed to be unique across a process.
                # Although importlib does not seem to perform any caching as of CPython 3.8, this is not guaranteed by spec.
                spec = importlib.util.spec_from_file_location(f'http2irc-module-{id(self)}-{i}', path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except Exception as e: # This is ugly, but exec_module can raise virtually any exception
                raise InvalidConfig(f'Loading module {path!r} failed: {e!s}') from e
            if not hasattr(module, 'process'):
                raise InvalidConfig(f'Module {path!r} does not have a process function')
            if not inspect.iscoroutinefunction(module.process):
                raise InvalidConfig(f'Module {path!r} process attribute is not a coroutine function')
            nargs = len(inspect.signature(module.process).parameters)
            if nargs != 1 + extraargs:
                raise InvalidConfig(f'Module {path!r} process function takes {nargs} parameter{"s" if nargs != 1 else ""}, not {1 + extraargs}')
            modules[path] = module

        # Replace module value in maps
        for map_ in maps.values():
            if map_['module'] is not None:
                map_['module'] = modules[map_['module']]

    def __repr__(self):
        return f'<Config(logging={self["logging"]!r}, irc={self["irc"]!r}, web={self["web"]!r}, maps={self["maps"]!r})>'

    def reread(self):
        '''Load the same file again, returning a new Config (raises InvalidConfig like the constructor).'''
        return Config(self._filename)
-
-
class MessageQueue:
    '''An object holding onto the messages received over HTTP for sending to IRC.

    This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    Differences to asyncio.Queue include:
    - No maxsize
    - No put coroutine (not necessary since the queue can never be full)
    - Only one concurrent getter
    - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)
    '''

    logger = logging.getLogger('http2irc.MessageQueue')

    def __init__(self):
        self._getter = None # None | asyncio.Future; set while get() waits for an item
        self._queue = collections.deque()

    async def get(self):
        '''Remove and return the first item, waiting until one is available.

        Raises RuntimeError if another get() is already waiting. If the waiting task is cancelled, nothing is removed.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        # Loop rather than waiting once so that a wake-up without an item can never surface as asyncio.QueueEmpty.
        while not self._queue:
            self._getter = asyncio.get_running_loop().create_future()
            self.logger.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                self.logger.debug('Cancelled getter')
                raise
            finally:
                self._getter = None
            self.logger.debug('Awaited getter')
        return self.get_nowait()

    def get_nowait(self):
        '''Remove and return the first item; raises asyncio.QueueEmpty if there is none.'''
        if not self._queue:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        '''Append item to the end of the queue.'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, *item):
        '''Put the given items at the front of the queue in argument order (the first argument is returned next).'''
        if not item:
            # Nothing added, so a waiting getter must not be woken.
            return
        self._queue.extendleft(reversed(item))
        self._wake_getter()

    def qsize(self):
        '''Number of items currently in the queue.'''
        return len(self._queue)

    def _wake_getter(self):
        '''Resolve the pending getter future, if any; done() also covers a cancelled future.'''
        if self._getter is not None and not self._getter.done():
            self._getter.set_result(None)
-
-
- class IRCClientProtocol(asyncio.Protocol):
- logger = logging.getLogger('http2irc.IRCClientProtocol')
-
def __init__(self, http2ircMessageQueue, irc2httpBroadcaster, connectionClosedEvent, loop, config, channels):
    '''Set up the state for one IRC connection; nothing is sent until connection_made.'''
    # Objects provided by the caller
    self.http2ircMessageQueue = http2ircMessageQueue
    self.irc2httpBroadcaster = irc2httpBroadcaster
    self.connectionClosedEvent = connectionClosedEvent
    self.loop = loop
    self.config = config

    # Transport and I/O state
    self.connected = False
    self.buffer = b'' # Received data after the last CRLF, i.e. an incomplete line
    self.lastRecvTime = None
    self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit
    self.sendQueue = asyncio.Queue()

    # Delivery confirmation via PING/PONG
    self.unconfirmedMessages = []
    self.pongReceivedEvent = asyncio.Event()

    # Registration; SASL (EXTERNAL) is used iff both a client certificate and its key are configured
    self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
    self.authenticated = False
    self.capReqsPending = set() # Capabilities requested from the server but not yet ACKd or NAKd
    self.caps = set() # Capabilities acknowledged by the server
    self.server = ircstates.Server(self.config['irc']['host'])

    # Channel membership
    self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
    self.kickedChannels = set() # Channels the bot got KICKed from (for re-INVITE purposes; reset on config reloading)
    self.userChannels = collections.defaultdict(set) # Which channels a user is known to be in; nickname:str -> {channel:str, ...}

    # WHOX (account lookup) state
    self.whoxQueue = collections.deque() # Names of channels that were joined successfully but for which no WHO (WHOX) query was sent yet
    self.whoxChannel = None # Name of channel for which a WHO query is currently running
    self.whoxReply = [] # Dicts (nick, hostmask, account) from the currently running WHO query
    self.whoxStartTime = None
-
- @staticmethod
- def nick_command(nick: str):
- return b'NICK ' + nick.encode('utf-8')
-
- @staticmethod
- def user_command(nick: str, real: str):
- nickb = nick.encode('utf-8')
- return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')
-
- def connection_made(self, transport):
- self.logger.info('IRC connected')
- self.transport = transport
- self.connected = True
- caps = [b'multi-prefix', b'userhost-in-names', b'away-notify', b'account-notify', b'extended-join']
- if self.sasl:
- caps.append(b'sasl')
- for cap in caps:
- self.capReqsPending.add(cap.decode('ascii'))
- self.send(b'CAP REQ :' + cap)
- self.send(self.nick_command(self.config['irc']['nick']))
- self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))
-
- def _send_join_part(self, command, channels):
- '''Split a JOIN or PART into multiple messages as necessary'''
- # command: b'JOIN' or b'PART'; channels: set[str]
-
- channels = [x.encode('utf-8') for x in channels]
- if len(command) + sum(1 + len(x) for x in channels) <= 510: # Total length = command + (separator + channel name for each channel, where the separator is a space for the first and then a comma)
- # Everything fits into one command.
- self.send(command + b' ' + b','.join(channels))
- return
-
- # List too long, need to split.
- limit = 510 - len(command)
- lengths = [1 + len(x) for x in channels] # separator + channel name
- chanLengthAcceptable = [l <= limit for l in lengths]
- if not all(chanLengthAcceptable):
- # There are channel names that are too long to even fit into one message on their own; filter them out and warn about them.
- # This should never happen since the config reader would already filter it out.
- tooLongChannels = [x for x, a in zip(channels, chanLengthAcceptable) if not a]
- channels = [x for x, a in zip(channels, chanLengthAcceptable) if a]
- lengths = [l for l, a in zip(lengths, chanLengthAcceptable) if a]
- for channel in tooLongChannels:
- self.logger.warning(f'Cannot {command} {channel}: name too long')
- runningLengths = list(itertools.accumulate(lengths)) # entry N = length of all entries up to and including channel N, including separators
- offset = 0
- while channels:
- i = next((x[0] for x in enumerate(runningLengths) if x[1] - offset > limit), -1)
- if i == -1: # Last batch
- i = len(channels)
- self.send(command + b' ' + b','.join(channels[:i]))
- offset = runningLengths[i-1]
- channels = channels[i:]
- runningLengths = runningLengths[i:]
-
- def update_channels(self, channels: set):
- channelsToPart = self.channels - channels
- channelsToJoin = channels - self.channels
- self.channels = channels
- self.kickedChannels = set()
-
- if self.connected:
- if channelsToPart:
- self._send_join_part(b'PART', channelsToPart)
- if channelsToJoin:
- self._send_join_part(b'JOIN', channelsToJoin)
-
- def send(self, data):
- self.logger.debug(f'Queueing for send: {data!r}')
- if len(data) > 510:
- raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
- self.sendQueue.put_nowait(data)
-
- def _direct_send(self, data):
- self.logger.debug(f'Send: {data!r}')
- time_ = time.time()
- self.transport.write(data + b'\r\n')
- if data.startswith(b'PRIVMSG '):
- # Send own messages to broadcaster as well
- command, channels, message = data.decode('utf-8').split(' ', 2)
- for channel in channels.split(','):
- assert channel.startswith('#') or channel.startswith('&'), f'invalid channel: {channel!r}'
- try:
- modes = self.get_mode_chars(self.server.channels[self.server.casefold(channel)].users.get(self.server.casefold(self.server.nickname)))
- except KeyError:
- # E.g. when kicked from a channel in the config
- # If the target channel isn't in self.server.channels, this *should* mean that the bot is not in the channel.
- # Therefore, don't send this to the broadcaster either.
- # TODO: Use echo-message instead and forward that to the broadcaster instead of emulating it here. Requires support from the network though...
- continue
- user = {
- 'nick': self.server.nickname,
- 'hostmask': f'{self.server.nickname}!{self.server.username}@{self.server.hostname}',
- 'account': self.server.account,
- 'modes': modes,
- }
- self.irc2httpBroadcaster.send(channel, {'time': time_, 'command': command, 'channel': channel, 'user': user, 'message': message})
- return time_
-
async def send_queue(self):
    '''Write queued lines to the transport until the connection closes.

    Once lastSentTime is set, at most one line per second is written (rate limit).
    '''
    while True:
        self.logger.debug('Trying to get data from send queue')
        getTask = asyncio.create_task(self.sendQueue.get())
        closedTask = asyncio.create_task(self.connectionClosedEvent.wait())
        done, _ = await wait_cancel_pending({getTask, closedTask}, return_when = asyncio.FIRST_COMPLETED)
        if self.connectionClosedEvent.is_set():
            return
        assert getTask in done, f'{getTask!r} is not in {done!r}'
        data = getTask.result()
        self.logger.debug(f'Got {data!r} from send queue')
        now = time.time()
        if self.lastSentTime is not None and now - self.lastSentTime < 1:
            # Sleep until one second after the previous send, but wake up early if the connection closes.
            self.logger.debug('Rate limited')
            await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now)
            if self.connectionClosedEvent.is_set():
                return
        sentAt = self._direct_send(data)
        if self.lastSentTime is not None:
            self.lastSentTime = sentAt
-
async def _get_message(self):
    '''Wait for the next (channel, message, overlongmode) tuple from the HTTP-to-IRC queue.

    Returns (None, None, None) if the connection closes first; a message that had already been dequeued at that
    point is put back at the front of the queue.
    '''
    self.logger.debug(f'Message queue {id(self.http2ircMessageQueue)} length: {self.http2ircMessageQueue.qsize()}')
    messageFuture = asyncio.create_task(self.http2ircMessageQueue.get())
    closedTask = asyncio.create_task(self.connectionClosedEvent.wait())
    # The queue getter is passed as persistent so that wait_cancel_pending leaves it alone; it is dealt with explicitly below.
    done, pending = await wait_cancel_pending({closedTask}, paws = {messageFuture}, return_when = asyncio.FIRST_COMPLETED)
    if not self.connectionClosedEvent.is_set():
        assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
        return messageFuture.result()
    if messageFuture not in pending:
        # messageFuture is already done but we're stopping, so put the result back onto the queue
        self.http2ircMessageQueue.putleft_nowait(messageFuture.result())
    else:
        self.logger.debug('Cancelling messageFuture')
        messageFuture.cancel()
        try:
            await messageFuture
        except asyncio.CancelledError:
            self.logger.debug('Cancelled messageFuture')
    return None, None, None
-
- def _self_usermask_length(self):
- if not self.server.nickname or not self.server.username or not self.server.hostname:
- return 100
- return len(self.server.nickname) + 1 + len(self.server.username) + 1 + len(self.server.hostname) # nickname!username@hostname
-
- async def send_messages(self):
- while self.connected:
- self.logger.debug(f'Trying to get a message')
- channel, message, overlongmode = await self._get_message()
- self.logger.debug(f'Got message: {message!r}')
- if message is None:
- break
- channelB = channel.encode('utf-8')
- messageB = message.encode('utf-8')
- usermaskPrefixLength = 1 + self._self_usermask_length() + 1 # :usermask<SP>
- if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
- # Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
- self.logger.debug(f'Message too long, overlongmode = {overlongmode}')
- prefix = b'PRIVMSG ' + channelB + b' :'
- prefixLength = usermaskPrefixLength + len(prefix) # Need to account for the origin prefix included by the ircd when sending to others
- maxMessageLength = 510 - prefixLength # maximum length of the message part within each line
- if overlongmode == 'truncate':
- maxMessageLength -= 3 # Make room for an ellipsis at the end
- messages = []
- while message:
- if overlongmode == 'truncate' and messages:
- break # Only need the first message on truncation
- if len(messageB) <= maxMessageLength:
- messages.append(message)
- break
-
- spacePos = messageB.rfind(b' ', 0, maxMessageLength + 1)
- if spacePos != -1:
- messages.append(messageB[:spacePos].decode('utf-8'))
- messageB = messageB[spacePos + 1:]
- message = messageB.decode('utf-8')
- continue
-
- # No space found, need to search for a suitable codepoint location.
- pMessage = message[:maxMessageLength] # at most 510 codepoints which expand to at least 510 bytes
- pLengths = [len(x.encode('utf-8')) for x in pMessage] # byte size of each codepoint
- pRunningLengths = list(itertools.accumulate(pLengths)) # byte size up to each codepoint
- if pRunningLengths[-1] <= maxMessageLength: # Special case: entire pMessage is short enough
- messages.append(pMessage)
- message = message[maxMessageLength:]
- messageB = message.encode('utf-8')
- continue
- cutoffIndex = next(x[0] for x in enumerate(pRunningLengths) if x[1] > maxMessageLength)
- messages.append(message[:cutoffIndex])
- message = message[cutoffIndex:]
- messageB = message.encode('utf-8')
- if overlongmode == 'split':
- for msg in reversed(messages):
- self.http2ircMessageQueue.putleft_nowait((channel, msg, overlongmode))
- elif overlongmode == 'truncate':
- self.http2ircMessageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
- else:
- self.logger.info(f'Sending {message!r} to {channel!r}')
- self.unconfirmedMessages.append((channel, message, overlongmode))
- self.send(b'PRIVMSG ' + channelB + b' :' + messageB)
-
async def confirm_messages(self):
    '''Periodically confirm delivery of sent messages through a PING/PONG round trip.

    Once per minute, or as soon as the connection closes, any unconfirmed messages are checked: a PING is sent and a
    PONG is awaited for up to five seconds. If it arrives, the messages count as delivered. If it does not, the messages
    are put back at the front of the HTTP-to-IRC queue and the transport is closed. After a disconnection, unconfirmed
    messages are requeued directly.
    '''
    while self.connected:
        await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
        if not self.connected:
            # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly.
            # Skip the call when there is nothing to requeue so that a getter waiting on the queue is not woken without an item.
            if self.unconfirmedMessages:
                self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
            self.unconfirmedMessages = []
            break
        if not self.unconfirmedMessages:
            self.logger.debug('No messages to confirm')
            continue
        self.logger.debug('Trying to confirm message delivery')
        self.pongReceivedEvent.clear()
        self._direct_send(b'PING :42')
        await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 5)
        self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}')
        if not self.pongReceivedEvent.is_set():
            # No PONG received in five seconds, assume connection's dead
            self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue')
            self.http2ircMessageQueue.putleft_nowait(*self.unconfirmedMessages)
            self.transport.close()
        self.unconfirmedMessages = []
-
- def data_received(self, data):
- time_ = time.time()
- self.logger.debug(f'Data received: {data!r}')
- self.lastRecvTime = time_
- # If there's any data left in the buffer, prepend it to the data. Split on CRLF.
- # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
- # If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
- if self.buffer:
- data = self.buffer + data
- messages = data.split(b'\r\n')
- for message in messages[:-1]:
- lines = self.server.recv(message + b'\r\n')
- assert len(lines) == 1, f'recv did not return exactly one line: {message!r} -> {lines!r}'
- self.message_received(time_, message, lines[0])
- self.server.parse_tokens(lines[0])
- self.buffer = messages[-1]
-
- def message_received(self, time_, message, line):
- self.logger.debug(f'Message received at {time_}: {message!r}')
-
- # Send to HTTP broadcaster
- # Note: WHOX is handled further down
- for d in self.line_to_dicts(time_, line):
- self.irc2httpBroadcaster.send(d['channel'], d)
-
- maybeTriggerWhox = False
-
- # PING/PONG
- if line.command == 'PING':
- self._direct_send(irctokens.build('PONG', line.params).format().encode('utf-8'))
- elif line.command == 'PONG':
- self.pongReceivedEvent.set()
-
- # IRCv3 and SASL
- elif line.command == 'CAP':
- if line.params[1] == 'ACK':
- for cap in line.params[2].split(' '):
- self.logger.debug(f'CAP ACK: {cap}')
- self.caps.add(cap)
- if cap == 'sasl' and self.sasl:
- self.send(b'AUTHENTICATE EXTERNAL')
- else:
- self.capReqsPending.remove(cap)
- elif line.params[1] == 'NAK':
- self.logger.warning(f'Failed to activate CAP(s): {line.params[2]}')
- for cap in line.params[2].split(' '):
- self.capReqsPending.remove(cap)
- if len(self.capReqsPending) == 0:
- self.send(b'CAP END')
- elif line.command == 'AUTHENTICATE' and line.params == ['+']:
- self.send(b'AUTHENTICATE +')
- elif line.command == ircstates.numerics.RPL_SASLSUCCESS:
- self.authenticated = True
- self.capReqsPending.remove('sasl')
- if len(self.capReqsPending) == 0:
- self.send(b'CAP END')
- elif line.command in ('902', ircstates.numerics.ERR_SASLFAIL, ircstates.numerics.ERR_SASLTOOLONG, ircstates.numerics.ERR_SASLABORTED, ircstates.numerics.RPL_SASLMECHS):
- self.logger.error('SASL error, terminating connection')
- self.transport.close()
-
- # NICK errors
- elif line.command in ('431', ircstates.numerics.ERR_ERRONEUSNICKNAME, ircstates.numerics.ERR_NICKNAMEINUSE, '436'):
- self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
- self.transport.close()
-
- # USER errors
- elif line.command in ('461', '462'):
- self.logger.error(f'Failed to register: {message!r}, terminating connection')
- self.transport.close()
-
- # JOIN errors
- elif line.command in (
- ircstates.numerics.ERR_TOOMANYCHANNELS,
- ircstates.numerics.ERR_CHANNELISFULL,
- ircstates.numerics.ERR_INVITEONLYCHAN,
- ircstates.numerics.ERR_BANNEDFROMCHAN,
- ircstates.numerics.ERR_BADCHANNELKEY,
- ):
- self.logger.error(f'Failed to join channel {line.params[1]}: {message!r}')
- errChannel = self.server.casefold(line.params[1])
- for channel in self.channels:
- if self.server.casefold(channel) == errChannel:
- self.channels.remove(channel)
- break
-
- # PART errors
- elif line.command == '442':
- self.logger.error(f'Failed to part channel: {message!r}')
-
- # JOIN/PART errors
- elif line.command == ircstates.numerics.ERR_NOSUCHCHANNEL:
- self.logger.error(f'Failed to join or part channel: {message!r}')
-
- # PRIVMSG errors
- elif line.command in (ircstates.numerics.ERR_NOSUCHNICK, '404', '407', '411', '412', '413', '414'):
- self.logger.error(f'Failed to send message: {message!r}')
-
- # Connection registration reply
- elif line.command == ircstates.numerics.RPL_WELCOME:
- self.logger.info('IRC connection registered')
- if self.sasl and not self.authenticated:
- self.logger.error('IRC connection registered but not authenticated, terminating connection')
- self.transport.close()
- return
- self.lastSentTime = time.time()
- self._send_join_part(b'JOIN', self.channels)
- asyncio.create_task(self.send_messages())
- asyncio.create_task(self.confirm_messages())
-
- # Bot getting KICKed
- elif line.command == 'KICK' and line.source and self.server.casefold(line.params[1]) == self.server.casefold(self.server.nickname):
- self.logger.warning(f'Got kicked from {line.params[0]}')
- kickedChannel = self.server.casefold(line.params[0])
- for channel in self.channels:
- if self.server.casefold(channel) == kickedChannel:
- self.channels.remove(channel)
- self.kickedChannels.add(channel) # Non-folded version so the set comparison in update_channels doesn't break.
- break
-
- # Bot getting INVITEd after a KICK
- elif line.command == 'INVITE' and line.source and self.server.casefold(line.params[0]) == self.server.casefold(self.server.nickname):
- invitedChannel = self.server.casefold(line.params[1])
- for channel in self.kickedChannels:
- if self.server.casefold(channel) == invitedChannel:
- self.channels.add(channel)
- self.kickedChannels.remove(channel)
- self._send_join_part(b'JOIN', {channel})
- break
-
- # WHOX on successful JOIN if supported to fetch account information
- elif line.command == 'JOIN' and self.server.isupport.whox and line.source and self.server.casefold(line.hostmask.nickname) == self.server.casefold(self.server.nickname):
- self.whoxQueue.extend(line.params[0].split(','))
- maybeTriggerWhox = True
-
- # WHOX response
- elif line.command == ircstates.numerics.RPL_WHOSPCRPL and line.params[1] == '042':
- self.whoxReply.append({'nick': line.params[4], 'hostmask': f'{line.params[4]}!{line.params[2]}@{line.params[3]}', 'account': line.params[5] if line.params[5] != '0' else None})
-
- # End of WHOX response
- elif line.command == ircstates.numerics.RPL_ENDOFWHO:
- # Patch ircstates account info; ircstates does not parse the WHOX reply itself.
- for entry in self.whoxReply:
- if entry['account']:
- self.server.users[self.server.casefold(entry['nick'])].account = entry['account']
- self.irc2httpBroadcaster.send(self.whoxChannel, {'time': time_, 'command': 'RPL_ENDOFWHO', 'channel': self.whoxChannel, 'users': self.whoxReply, 'whoxstarttime': self.whoxStartTime})
- self.whoxChannel = None
- self.whoxReply = []
- self.whoxStartTime = None
- maybeTriggerWhox = True
-
- # General fatal ERROR
- elif line.command == 'ERROR':
- self.logger.error(f'Server sent ERROR: {message!r}')
- self.transport.close()
-
- # Send next WHOX if appropriate
- if maybeTriggerWhox and self.whoxChannel is None and self.whoxQueue:
- self.whoxChannel = self.whoxQueue.popleft()
- self.whoxReply = []
- self.whoxStartTime = time.time() # Note, may not be the actual start time due to rate limiting
- self.send(b'WHO ' + self.whoxChannel.encode('utf-8') + b' c%tuhna,042')
-
def get_mode_chars(self, channelUser):
    '''Return the channel status prefix characters (e.g. '@+') held by channelUser.

    Characters are ordered by their position in the server's ISUPPORT PREFIX modes list; modes that
    PREFIX does not list are ignored. A channelUser of None yields an empty string.
    '''
    if channelUser is None:
        return ''
    prefixInfo = self.server.isupport.prefix
    ranks = sorted(prefixInfo.modes.index(mode) for mode in channelUser.modes if mode in prefixInfo.modes)
    return ''.join([prefixInfo.prefixes[rank] for rank in ranks])
-
def line_to_dicts(self, time_, line):
    '''Translate one parsed IRC line into event dicts for irc2http subscribers.

    This is a generator, and time_ is copied into every event's 'time' key. Some commands concern several
    channels: comma-separated JOIN/PART targets, and QUIT/NICK/ACCOUNT, which are reported in each channel
    of the user. These yield one event per channel. The following yield nothing:
    - commands not handled here;
    - PRIVMSG/NOTICE whose target is not a channel tracked in self.server.

    self.server's channel and user maps are keyed by casefolded names, so every lookup casefolds first.
    '''
    casefold = self.server.casefold
    if line.source:
        sourceNick = line.hostmask.nickname
        sourceUser = self.server.users.get(casefold(sourceNick))

        def get_modes(channel, nick = sourceNick):
            # Prefix mode characters of nick in channel ('' if nick is not in that channel).
            return self.get_mode_chars(self.server.channels[casefold(channel)].users.get(casefold(nick)))

        def get_user(channel, withModes = True):
            # Description of the line's source user; 'modes' are their prefix modes in channel.
            user = {
                'nick': sourceNick,
                'hostmask': str(line.hostmask),
                'account': getattr(self.server.users.get(casefold(sourceNick)), 'account', None),
            }
            if withModes:
                user['modes'] = get_modes(channel)
            return user

    if line.command == 'JOIN':
        # Although servers SHOULD NOT send multiple channels in one message per the modern IRC docs <https://modern.ircdocs.horse/#join-message>, let's do the safe thing...
        account = {}
        if 'extended-join' in self.caps:
            # extended-join: JOIN <channel> <account> :<realname>; '*' is reported as no account.
            account = {'account': line.params[-2] if line.params[-2] != '*' else None}
        for channel in line.params[0].split(','):
            # There can't be a mode set yet on the JOIN, so no need to use get_modes (which would complicate the self-join).
            yield {'time': time_, 'command': 'JOIN', 'channel': channel, 'user': {**get_user(channel, False), **account}}
    elif line.command in ('PRIVMSG', 'NOTICE'):
        channel = line.params[0]
        # self.server.channels is keyed by casefolded names. Comparing the raw target would drop
        # messages whose channel name differs from the folded key only in case.
        if casefold(channel) not in self.server.channels:
            return  # Not a tracked channel (e.g. a private message to the bot)
        text = line.params[1]
        if line.command == 'PRIVMSG' and text.startswith('\x01ACTION ') and text.endswith('\x01'):
            # CTCP ACTION (aka /me)
            yield {'time': time_, 'command': 'ACTION', 'channel': channel, 'user': get_user(channel), 'message': text[8:-1]}
            return
        yield {'time': time_, 'command': line.command, 'channel': channel, 'user': get_user(channel), 'message': text}
    elif line.command == 'PART':
        for channel in line.params[0].split(','):
            yield {'time': time_, 'command': 'PART', 'channel': channel, 'user': get_user(channel), 'reason': line.params[1] if len(line.params) == 2 else None}
    elif line.command in ('QUIT', 'NICK', 'ACCOUNT'):
        # These carry no channel. Report them in the bot's configured channels (for the bot itself) or in
        # the channels self.server tracks for the user. Nick comparison is casefolded like elsewhere.
        if casefold(sourceNick) == casefold(self.server.nickname):
            channels = self.channels
        elif sourceUser is not None:
            channels = sourceUser.channels
        else:
            return
        if line.command == 'QUIT':
            extra = {'reason': line.params[0] if len(line.params) == 1 else None}
        elif line.command == 'NICK':
            extra = {'newnick': line.params[0]}
        else:
            extra = {'account': line.params[0]}
        for channel in channels:
            yield {'time': time_, 'command': line.command, 'channel': channel, 'user': get_user(channel), **extra}
    elif line.command == 'MODE' and line.params[0].startswith(('#', '&')):
        # Channel modes only; a MODE whose target is not a '#'/'&' channel is ignored.
        channel = line.params[0]
        yield {'time': time_, 'command': 'MODE', 'channel': channel, 'user': get_user(channel), 'args': line.params[1:]}
    elif line.command == 'KICK':
        channel = line.params[0]
        targetUser = self.server.users[casefold(line.params[1])]
        yield {
            'time': time_,
            'command': 'KICK',
            'channel': channel,
            'user': get_user(channel),
            'targetuser': {'nick': targetUser.nickname, 'hostmask': targetUser.hostmask(), 'modes': get_modes(channel, targetUser.nickname), 'account': targetUser.account},
            'reason': line.params[2] if len(line.params) == 3 else None,
        }
    elif line.command == 'TOPIC':
        channel = line.params[0]
        channelObj = self.server.channels[casefold(channel)]
        oldTopic = None
        if channelObj.topic:
            oldTopic = {'topic': channelObj.topic, 'setter': channelObj.topic_setter, 'time': channelObj.topic_time.timestamp() if channelObj.topic_time else None}
        # An empty new topic means the topic was cleared; report it as None.
        yield {'time': time_, 'command': 'TOPIC', 'channel': channel, 'user': get_user(channel), 'oldtopic': oldTopic, 'newtopic': line.params[1] or None}
    elif line.command == ircstates.numerics.RPL_TOPIC:
        channel = line.params[1]
        yield {'time': time_, 'command': 'RPL_TOPIC', 'channel': channel, 'topic': line.params[2]}
    elif line.command == ircstates.numerics.RPL_TOPICWHOTIME:
        yield {'time': time_, 'command': 'RPL_TOPICWHOTIME', 'channel': line.params[1], 'setter': {'nick': irctokens.hostmask(line.params[2]).nickname, 'hostmask': line.params[2]}, 'topictime': int(line.params[3])}
    elif line.command == ircstates.numerics.RPL_ENDOFNAMES:
        channel = line.params[1]
        users = self.server.channels[casefold(channel)].users
        yield {'time': time_, 'command': 'NAMES', 'channel': channel, 'users': [{'nick': u.nickname, 'modes': self.get_mode_chars(u)} for u in users.values()]}
-
async def quit(self):
    '''Send QUIT and wait up to 10 seconds for the connection to close, forcing it closed otherwise.

    The server acknowledges a QUIT by sending an ERROR and closing the connection. That closure triggers
    connection_lost, which sets connectionClosedEvent, so only that event is awaited here.
    '''
    self.logger.info('Quitting')
    # A far-future "last sent" timestamp disables sending any further messages in send_queue.
    self.lastSentTime = 1.67e34 * math.pi * 1e7
    self._direct_send(b'QUIT :Bye')
    closedWaiter = asyncio.create_task(self.connectionClosedEvent.wait())
    await wait_cancel_pending({closedWaiter}, timeout = 10)
    if self.connectionClosedEvent.is_set():
        return
    self.logger.error('Quitting cleanly did not work, closing connection forcefully')
    # connectionClosedEvent will be set by connection_lost once the transport is closed.
    self.transport.close()
-
def connection_lost(self, exc):
    '''asyncio.Protocol callback invoked once the IRC transport has been closed.

    exc is None for a regular close (EOF, or closed by this side). Otherwise it is the exception that
    broke the connection, and it is logged together with its type. Either way the protocol is marked
    disconnected, connectionClosedEvent is set, and a CONNLOST event is broadcast to all channels.
    '''
    time_ = time.time()
    if exc is None:
        self.logger.info('IRC connection lost')
    else:
        self.logger.warning(f'IRC connection lost: {type(exc).__module__}.{type(exc).__name__}: {exc!s}')
    self.connected = False
    self.connectionClosedEvent.set()
    self.irc2httpBroadcaster.send(Broadcaster.ALL_CHANNELS, {'time': time_, 'command': 'CONNLOST'})
-
-
class IRCClient:
    '''Owns the IRC connection.

    It (re)connects in a loop until SIGINT and forwards configuration reloads to the live protocol.
    '''

    logger = logging.getLogger('http2irc.IRCClient')

    def __init__(self, http2ircMessageQueue, irc2httpBroadcaster, config):
        self.http2ircMessageQueue = http2ircMessageQueue
        self.irc2httpBroadcaster = irc2httpBroadcaster
        self.config = config
        self.channels = self._channels_from_config(config)

        self._transport = None
        self._protocol = None

    @staticmethod
    def _channels_from_config(config):
        '''Set of IRC channels referenced by config['maps'].'''
        return {map_['ircchannel'] for map_ in config['maps'].values()}

    def update_config(self, config):
        '''Apply a reloaded config.

        The channel set is always refreshed. A change in config['irc'] closes the current connection, and
        run() then reconnects with the new settings. Otherwise the live protocol receives the new channel
        set via update_channels.
        '''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        # Refresh unconditionally: run() hands self.channels to every new connection. Updating it only
        # while connected let a stale set survive reconnects and reloads done while disconnected.
        self.channels = self._channels_from_config(config)
        if not self._transport:  # not currently connected
            return
        if needReconnect:
            self._transport.close()
        else:
            self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Value for create_connection's ssl argument.

        Returns True (verified TLS), False (plaintext) or an SSLContext. When a client certificate is
        configured, its chain is loaded into a context created for this connection.
        '''
        ircConfig = self.config['irc']
        ctx = SSL_CONTEXTS[ircConfig['ssl']]
        if not (ircConfig['certfile'] and ircConfig['certkeyfile']):
            return ctx
        if ctx is True:
            ctx = ssl.create_default_context()
        elif isinstance(ctx, ssl.SSLContext):
            # Never load the client certificate into the shared module-level 'insecure' context: it would
            # stay there even after a config reload removes or changes the certificate.
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ctx.check_hostname = False
            ctx.verify_mode = ssl.CERT_NONE
        if isinstance(ctx, ssl.SSLContext):
            ctx.load_cert_chain(ircConfig['certfile'], keyfile = ircConfig['certkeyfile'])
        return ctx

    async def run(self, loop, sigintEvent):
        '''Keep an IRC connection alive until sigintEvent is set.

        Each iteration proceeds as follows:
        - connect, giving up after 30 seconds or on SIGINT;
        - start the protocol's send queue;
        - wait for the connection to close or for SIGINT (on SIGINT, quit cleanly);
        - after connection errors or closure, retry after 5 seconds.
        A CancelledError aimed at run() itself is not caught here, so the task can actually be cancelled.
        '''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self.logger.debug('Creating IRC connection')
                t = asyncio.create_task(loop.create_connection(
                    protocol_factory = lambda: IRCClientProtocol(self.http2ircMessageQueue, self.irc2httpBroadcaster, connectionClosedEvent, loop, self.config, self.channels),
                    host = self.config['irc']['host'],
                    port = self.config['irc']['port'],
                    ssl = self._get_ssl_context(),
                    family = self.config['irc']['family'],
                ))
                # No automatic cancellation of t because it's handled manually below.
                done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30)
                if t not in done:
                    # Timeout or SIGINT: abort the attempt. The CancelledError from awaiting t is our own doing,
                    # so report it as a timeout instead of catching CancelledError wholesale.
                    t.cancel()
                    try:
                        await t
                    except asyncio.CancelledError:
                        raise asyncio.TimeoutError('IRC connection attempt aborted') from None
                self._transport, self._protocol = t.result()
                self.logger.debug('Starting send queue processing')
                sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent
                self.logger.debug('Waiting for connection closure or SIGINT')
                try:
                    await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
                finally:
                    self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}')
                    if not connectionClosedEvent.is_set():
                        self.logger.debug('Quitting connection')
                        await self._protocol.quit()
                    if not sendTask.done():
                        sendTask.cancel()
                        try:
                            await sendTask
                        except asyncio.CancelledError:
                            pass
                    self._transport = None
                    self._protocol = None
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers ConnectionError and ssl.SSLError as well as DNS failures (socket.gaierror) and the
                # plain OSError that create_connection raises when every resolved address fails.
                self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}')
            await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5)
            if sigintEvent.is_set():
                self.logger.debug('Got SIGINT, breaking IRC loop')
                break

    @property
    def lastRecvTime(self):
        '''Time the current protocol last received data, or None while not connected.'''
        return self._protocol.lastRecvTime if self._protocol else None
-
-
class Broadcaster:
    '''Fan-out of JSON-encoded events to per-channel subscriber queues.

    Every subscriber owns an asyncio.Queue. send() serialises an event once and puts the same bytes into
    every queue subscribed to the target channel, or to every channel for ALL_CHANNELS.
    '''

    ALL_CHANNELS = object() # Sentinel for send's channel argument: deliver to every channel, e.g. on connection loss

    def __init__(self):
        # channel -> set of subscriber queues; entries are dropped once their last subscriber leaves
        self._queues = {}

    def subscribe(self, channel):
        '''Register a new subscriber for channel and return its queue.'''
        subscriberQueue = asyncio.Queue()
        self._queues.setdefault(channel, set()).add(subscriberQueue)
        return subscriberQueue

    def _send(self, channel, j):
        '''Put the already-encoded payload j into each queue subscribed to channel.'''
        for subscriberQueue in self._queues[channel]:
            subscriberQueue.put_nowait(j)

    def send(self, channel, d):
        '''Compact-JSON-encode d and deliver it to channel's subscribers; a no-op when nobody listens.'''
        if channel is self.ALL_CHANNELS:
            targets = self._queues
        elif channel in self._queues:
            targets = [channel]
        else:
            return
        if not targets:
            return
        payload = json.dumps(d, separators = (',', ':')).encode('utf-8')
        for target in targets:
            self._send(target, payload)

    def unsubscribe(self, channel, queue):
        '''Remove queue from channel's subscribers, forgetting the channel when none remain.'''
        subscribers = self._queues[channel]
        subscribers.remove(queue)
        if not subscribers:
            del self._queues[channel]
-
-
class WebServer:
    '''aiohttp front end of the bridge.

    - POST /<webpath> turns the request into a message for the mapped IRC channel and enqueues it on
      http2ircMessageQueue.
    - GET /<webpath> streams that channel's events from irc2httpBroadcaster, one JSON document per line.
    - GET /status reports whether the IRC client received anything recently.
    '''

    logger = logging.getLogger('http2irc.WebServer')

    def __init__(self, http2ircMessageQueue, irc2httpBroadcaster, ircClient, config):
        self.http2ircMessageQueue = http2ircMessageQueue
        self.irc2httpBroadcaster = irc2httpBroadcaster
        self.ircClient = ircClient
        self.config = config

        # Created before the first update_config call so update_config can always signal a rebind.
        self._configChanged = asyncio.Event()
        self.stopEvent = None

        self._paths = {}
        # '/path' => ('#channel', postauth, getauth, module, moduleargs, overlongmode)
        # {post,get}auth are either False (access denied) or the HTTP header value for basic auth

        self._routes = [
            aiohttp.web.get('/status', self.get_status),
            aiohttp.web.post('/{path:.+}', functools.partial(self._path_request, func = self.post)),
            aiohttp.web.get('/{path:.+}', functools.partial(self._path_request, func = self.get)),
        ]

        self.update_config(config)

    @staticmethod
    def _basic_auth_header(credentials):
        '''Expected Authorization header value for a credentials string.

        For Basic auth the string is conventionally 'user:password'. Empty credentials return False,
        meaning access is denied.
        '''
        if not credentials:
            return False
        return f'Basic {base64.b64encode(credentials.encode("utf-8")).decode("utf-8")}'

    def update_config(self, config):
        '''Rebuild the path table from config['maps'] and request a rebind if config['web'] changed.'''
        self._paths = {map_['webpath']: (
            map_['ircchannel'],
            self._basic_auth_header(map_['postauth']),
            self._basic_auth_header(map_['getauth']),
            map_['module'],
            map_['moduleargs'],
            map_['overlongmode'],
        ) for map_ in config['maps'].values()}
        needRebind = self.config['web'] != config['web']
        self.config = config
        if needRebind:
            self._configChanged.set()

    async def run(self, stopEvent):
        '''Serve until stopEvent is set, rebuilding the aiohttp app whenever the web config changes.'''
        self.stopEvent = stopEvent
        while True:
            app = aiohttp.web.Application(client_max_size = self.config['web']['maxrequestsize'])
            app.add_routes(self._routes)
            runner = aiohttp.web.AppRunner(app)
            await runner.setup()
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED)
            finally:
                # Release the listening socket and connections even if binding or waiting failed.
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

    async def get_status(self, request):
        '''200 if the IRC client received data within the last 600 seconds, otherwise 500.'''
        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.path!r}')
        if (self.ircClient.lastRecvTime or 0) > time.time() - 600:
            return aiohttp.web.Response()
        # Raise rather than return: aiohttp deprecates returning HTTPException instances from handlers.
        raise aiohttp.web.HTTPInternalServerError()

    async def _path_request(self, request, func):
        '''Resolve the configured path, check Basic auth for the request method, then dispatch to func.'''
        import hmac  # Only needed here, for the constant-time credential comparison.
        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.method} {request.path!r} with body {(await request.read())!r}')
        try:
            pathConfig = self._paths[request.path]
        except KeyError:
            self.logger.info(f'Bad request {id(request)}: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        auth = pathConfig[1] if request.method == 'POST' else pathConfig[2]
        authHeader = request.headers.get('Authorization')
        if not authHeader or not auth or not hmac.compare_digest(authHeader.encode('utf-8', 'surrogatepass'), auth.encode('utf-8')):
            # Never log the expected header (it is the configured secret) or the submitted credentials.
            reason = 'no credentials supplied' if not authHeader else 'access disabled' if not auth else 'wrong credentials'
            self.logger.info(f'Bad request {id(request)}: authentication failed ({reason})')
            raise aiohttp.web.HTTPForbidden()
        return (await func(request, *pathConfig))

    async def post(self, request, channel, postauth, getauth, module, moduleargs, overlongmode):
        '''Turn an authenticated POST into an IRC message and enqueue it for channel.

        With a module, module.process(request, *moduleargs) produces the message; None means "accept but
        send nothing". Without one, the request body is used. Always finishes by raising an aiohttp HTTP
        exception: 200 on success, 400 on bad input.
        '''
        if module is not None:
            self.logger.debug(f'Processing request {id(request)} using {module!r}')
            try:
                message = await module.process(request, *moduleargs)
            except aiohttp.web.HTTPException:
                raise # Deliberate HTTP responses from the module pass through unchanged.
            except Exception as e:
                self.logger.error(f'Bad request {id(request)}: exception in module process function: {type(e).__module__}.{type(e).__name__}: {e!s}')
                raise aiohttp.web.HTTPBadRequest()
            if message is None:
                self.logger.info(f'Accepted request {id(request)}, module returned None')
                raise aiohttp.web.HTTPOk()
            if '\r' in message or '\n' in message:
                self.logger.error(f'Bad request {id(request)}: module process function returned message with linebreaks: {message!r}')
                raise aiohttp.web.HTTPBadRequest()
        else:
            self.logger.debug(f'Processing request {id(request)} using default processor')
            message = await self._default_process(request)
        self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
        self.http2ircMessageQueue.put_nowait((channel, message, overlongmode))
        raise aiohttp.web.HTTPOk()

    async def _default_process(self, request):
        '''Return the request body as a single line, dropping one trailing [CR]LF; 400 if other linebreaks remain.'''
        try:
            message = await request.text()
        except Exception as e:
            self.logger.info(f'Bad request {id(request)}: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        self.logger.debug(f'Request {id(request)} payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            self.logger.info(f'Bad request {id(request)}: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        return message

    async def get(self, request, channel, postauth, getauth, module, moduleargs, overlongmode):
        '''Stream channel's broadcast events as JSON lines until shutdown or a web config change.'''
        self.logger.info(f'Subscribing listener from request {id(request)} for {channel}')
        queue = self.irc2httpBroadcaster.subscribe(channel)
        response = aiohttp.web.StreamResponse()
        response.enable_chunked_encoding()
        await response.prepare(request)
        try:
            while True:
                t = asyncio.create_task(queue.get())
                done, _ = await wait_cancel_pending({t, asyncio.create_task(self.stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED)
                if t not in done: # stopEvent or config change
                    #TODO Don't break if the config change doesn't affect this connection
                    break
                j = t.result()
                await response.write(j + b'\n')
        finally:
            self.irc2httpBroadcaster.unsubscribe(channel, queue)
            self.logger.info(f'Unsubscribed listener from request {id(request)} for {channel}')
        return response
-
-
def configure_logging(config):
    '''(Re)configure the root logger from config['logging'].

    Sets the level named by config['logging']['level'] and replaces every root handler with a single
    stderr handler. That handler uses the '{'-style format string config['logging']['format']. Safe to
    call repeatedly (main calls it again on SIGUSR1 reloads).
    '''
    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    root = logging.getLogger()
    root.setLevel(getattr(logging, config['logging']['level']))
    # Detach and close previous handlers through the public API (as basicConfig(force=True) does)
    # instead of overwriting the handlers list.
    for oldHandler in list(root.handlers):
        root.removeHandler(oldHandler)
        oldHandler.close()
    formatter = logging.Formatter(config['logging']['format'], style = '{')
    stderrHandler = logging.StreamHandler()
    stderrHandler.setFormatter(formatter)
    root.addHandler(stderrHandler)
-
-
async def main():
    '''Program entry point.

    Loads the config file named on the command line, then runs the IRC client and the web server
    together. SIGINT stops both. SIGUSR1 rereads the config and applies it to logging, the IRC client
    and the web server; if the new config is invalid, the old one stays active.
    '''
    if len(sys.argv) != 2:
        print('Usage: http2irc.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    config = Config(sys.argv[1])
    configure_logging(config)

    loop = asyncio.get_running_loop()

    http2ircMessageQueue = MessageQueue()
    irc2httpBroadcaster = Broadcaster()
    irc = IRCClient(http2ircMessageQueue, irc2httpBroadcaster, config)
    webserver = WebServer(http2ircMessageQueue, irc2httpBroadcaster, irc, config)

    sigintEvent = asyncio.Event()

    def on_sigint():
        logger.info('Got SIGINT, stopping')
        sigintEvent.set()

    def on_sigusr1():
        nonlocal config  # rebound below so later reloads reread the newest config
        logger.info('Got SIGUSR1, reloading config')
        try:
            reloaded = config.reread()
        except InvalidConfig as e:
            logger.error(f'Config reload failed: {e!s} (old config remains active)')
            return
        config = reloaded
        configure_logging(config)
        irc.update_config(config)
        webserver.update_config(config)

    loop.add_signal_handler(signal.SIGINT, on_sigint)
    loop.add_signal_handler(signal.SIGUSR1, on_sigusr1)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))


if __name__ == '__main__':
    asyncio.run(main())
|