|
- import aiohttp
- import aiohttp.web
- import asyncio
- import base64
- import collections
- import importlib.util
- import inspect
- import ircstates
- import irctokens
- import itertools
- import logging
- import os.path
- import signal
- import ssl
- import string
- import sys
- import tempfile
- import time
- import toml
-
-
logger = logging.getLogger('irclog')


def make_insecure_ssl_context():
    '''Return a client TLS context that performs no certificate or hostname verification.

    ssl.SSLContext() without an explicit protocol is deprecated since Python 3.10, so the context is created from
    PROTOCOL_TLS_CLIENT and verification is switched off explicitly. check_hostname has to be disabled before
    verify_mode may be set to CERT_NONE.
    '''
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx


# Maps the irc.ssl config value to the `ssl` argument of loop.create_connection:
# True = default verifying TLS, False = plaintext, SSLContext = TLS without any verification.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': make_insecure_ssl_context()}
messageConnectionClosed = object() # Signals that the connection was closed by either the bot or the server
messageEOF = object() # Special object to signal the end of messages to Storage
-
-
class InvalidConfig(Exception):
    '''Raised when the configuration file contains an error.'''

    pass
-
-
def is_valid_pem(path, withCert):
    '''Very basic check whether a file looks like a PEM private key, optionally preceded by a certificate.

    With withCert true, the file must consist of exactly one CERTIFICATE block immediately followed by one PRIVATE KEY
    block. Otherwise, it must consist of exactly one PRIVATE KEY block. Blocks must use LF line endings and contain
    valid base64. No cryptographic validation is performed.

    Returns True if the file matches, False otherwise (including when it cannot be read).
    '''
    certBegin, certEnd = b'-----BEGIN CERTIFICATE-----\n', b'-----END CERTIFICATE-----\n'
    keyBegin, keyEnd = b'-----BEGIN PRIVATE KEY-----\n', b'-----END PRIVATE KEY-----\n'

    def is_base64(data):
        try:
            base64.b64decode(data.replace(b'\n', b''), validate = True)
        except ValueError: # binascii.Error, raised for malformed input, is a subclass of ValueError
            return False
        return True

    try:
        with open(path, 'rb') as fp:
            contents = fp.read()
    except (OSError, ValueError): # ValueError: e.g. embedded NUL in the path
        return False

    # Explicit checks rather than assert statements, which would be stripped under `python -O`.
    keyStart = 0
    if withCert:
        if not contents.startswith(certBegin):
            return False
        certEndPos = contents.find(certEnd, len(certBegin))
        if certEndPos == -1 or not is_base64(contents[len(certBegin):certEndPos]):
            return False
        keyStart = certEndPos + len(certEnd) # The key block must follow the certificate immediately

    if not contents.startswith(keyBegin, keyStart):
        return False
    keyEndPos = contents.find(keyEnd, keyStart + len(keyBegin))
    if keyEndPos == -1 or not is_base64(contents[keyStart + len(keyBegin):keyEndPos]):
        return False
    # Nothing may follow the key block.
    return keyEndPos + len(keyEnd) == len(contents)
-
-
class Config(dict):
    '''Validated configuration of the IRC logger, read from a TOML file.

    The instance is a dict with exactly the keys 'logging', 'storage', 'irc', 'web' and 'channels'. Each value is a
    dict containing the built-in defaults, updated with whatever the file specifies.

    Relative paths in the file (storage.path, irc.certfile, irc.certkeyfile) are resolved against the directory that
    contains the config file. Every channel entry gets 'ircchannel' (default '#' + key), 'auth' (default False) and
    'active' (default True).

    Raises InvalidConfig if the file cannot be read or parsed, or if any value fails validation.
    '''

    # Allowed top-level sections, in the order in which they are stored on the instance
    _SECTIONS = ('logging', 'storage', 'irc', 'web', 'channels')

    def __init__(self, filename):
        super().__init__()
        self._filename = filename

        try:
            # TOML documents are UTF-8 by specification; don't depend on the locale's default encoding.
            with open(self._filename, 'r', encoding = 'utf-8') as fp:
                obj = toml.load(fp)
        except (OSError, ValueError) as e: # toml.TomlDecodeError and UnicodeDecodeError are ValueError subclasses
            raise InvalidConfig(f'Could not read config file {self._filename!r}: {e!s}') from e

        # Sanity checks
        if any(x not in self._SECTIONS for x in obj.keys()):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')

        finalObj = self._default_values()
        if 'logging' in obj:
            self._check_logging(obj['logging'])
        if 'storage' in obj:
            self._check_storage(obj['storage'])
        if 'irc' in obj:
            self._check_irc(obj['irc'], finalObj['irc'])
        if 'web' in obj:
            self._check_web(obj['web'])
        if 'channels' in obj:
            self._check_channels(obj['channels']) # Also fills in the per-channel defaults

        # Merge in what was read from the config file and set keys on self
        for key in self._SECTIONS:
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    def _default_values(self):
        '''Return a fresh dict with the default value of every section.'''
        return {
            'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
            'storage': {'path': os.path.abspath(os.path.dirname(self._filename))},
            'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'irclogbot', 'real': 'I am an irclog bot.', 'certfile': None, 'certkeyfile': None},
            'web': {'host': '127.0.0.1', 'port': 8080},
            'channels': {},
        }

    def _resolve_path(self, path):
        '''Return path as an absolute path, interpreting relative paths against the config file's directory.'''
        return os.path.abspath(os.path.join(os.path.dirname(self._filename), path))

    @staticmethod
    def _is_valid_port(value):
        '''Return whether value is an int in 1..65535 (bools are rejected even though they are ints).'''
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    def _check_logging(self, section):
        '''Validate the [logging] section.'''
        if any(x not in ('level', 'format') for x in section):
            raise InvalidConfig('Unknown key found in log section')
        if 'level' in section and section['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
            raise InvalidConfig('Invalid log level')
        if 'format' in section:
            if not isinstance(section['format'], str):
                raise InvalidConfig('Invalid log format')
            try:
                #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
                # This counts the number of replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                fieldCount = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
            except ValueError as e:
                raise InvalidConfig('Invalid log format: parsing failed') from e
            # Explicit check instead of an assert, which would be stripped under `python -O`
            if fieldCount == 0:
                raise InvalidConfig('Invalid log format: parsing failed')

    def _check_storage(self, section):
        '''Validate the [storage] section in place, resolving path to an absolute path.'''
        if any(x != 'path' for x in section):
            raise InvalidConfig('Unknown key found in storage section')
        if 'path' in section:
            if not isinstance(section['path'], str):
                raise InvalidConfig('Invalid storage path: not a string')
            section['path'] = self._resolve_path(section['path'])
            try:
                #TODO This doesn't seem to work correctly; doesn't fail when the dir is -w
                with tempfile.TemporaryFile(dir = section['path']):
                    pass
            except OSError as e:
                raise InvalidConfig('Invalid storage path: not writable') from e

    def _check_irc(self, section, defaults):
        '''Validate the [irc] section in place, resolving certfile/certkeyfile to absolute paths.

        defaults is the default irc section. It supplies nick/real for the command length checks when the file
        omits them, so that those checks always run against the values that will actually be used.
        '''
        if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')
        if 'nick' in section and not isinstance(section['nick'], str): #TODO: Check whether it's a valid nickname, username, etc.
            raise InvalidConfig('Invalid IRC nick')
        nick = section.get('nick', defaults['nick'])
        if len(IRCClientProtocol.nick_command(nick)) > 510:
            raise InvalidConfig('Invalid IRC nick: NICK command too long')
        if 'real' in section and not isinstance(section['real'], str):
            raise InvalidConfig('Invalid IRC realname')
        real = section.get('real', defaults['real'])
        if len(IRCClientProtocol.user_command(nick, real)) > 510:
            raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')
        if ('certfile' in section) != ('certkeyfile' in section):
            raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
        if 'certfile' in section:
            section['certfile'] = self._check_pem_file(section['certfile'], 'certificate file', 'cert', True)
        if 'certkeyfile' in section:
            section['certkeyfile'] = self._check_pem_file(section['certkeyfile'], 'certificate key file', 'key', False)

    def _check_pem_file(self, value, label, kind, withCert):
        '''Validate a configured PEM file path and return it as an absolute path.'''
        if not isinstance(value, str):
            raise InvalidConfig(f'Invalid {label}: not a string')
        path = self._resolve_path(value)
        if not os.path.isfile(path):
            raise InvalidConfig(f'Invalid {label}: not a regular file')
        if not is_valid_pem(path, withCert):
            raise InvalidConfig(f'Invalid {label}: not a valid PEM {kind}')
        return path

    def _check_web(self, section):
        '''Validate the [web] section.'''
        if any(x not in ('host', 'port') for x in section):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid web port')

    def _check_channels(self, channels):
        '''Validate the [channels] section in place and fill in the per-channel defaults.'''
        seenChannels = {} # IRC channel name -> config key that claimed it first
        for key, channel in channels.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid channel key {key!r}')
            if not isinstance(channel, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid channel for {key!r}')
            if any(x not in ('ircchannel', 'auth', 'active') for x in channel):
                raise InvalidConfig(f'Unknown key(s) found in channel {key!r}')

            if 'ircchannel' not in channel:
                channel['ircchannel'] = f'#{key}'
            ircchannel = channel['ircchannel']
            if not isinstance(ircchannel, str):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: not a string')
            if not ircchannel.startswith(('#', '&')):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: does not start with # or &')
            if any(x in ircchannel[1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: contains forbidden characters')
            if len(ircchannel) > 200:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: too long')
            if ircchannel in seenChannels:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: collides with channel {seenChannels[ircchannel]!r}')
            seenChannels[ircchannel] = key

            if 'auth' in channel:
                if channel['auth'] is not False and not isinstance(channel['auth'], str):
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must be false or a string')
                if isinstance(channel['auth'], str) and ':' not in channel['auth']:
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must contain a colon')
            else:
                channel['auth'] = False

            if 'active' in channel:
                if channel['active'] is not True and channel['active'] is not False:
                    raise InvalidConfig(f'Invalid channel {key!r} active: must be true or false')
            else:
                channel['active'] = True

    def __repr__(self):
        return f'<Config(logging={self["logging"]!r}, storage={self["storage"]!r}, irc={self["irc"]!r}, web={self["web"]!r}, channels={self["channels"]!r})>'

    def reread(self):
        '''Load the same file again and return a new Config. Raises InvalidConfig on failure.'''
        return Config(self._filename)
-
-
class IRCClientProtocol(asyncio.Protocol):
    '''asyncio protocol for one IRC connection.

    Registers with NICK/USER, optionally authenticating with SASL EXTERNAL when a client certificate is configured.
    Answers PINGs, joins/parts channels, and puts every sent and received line on messageQueue as a tuple
    (time, b'> ' or b'< ' + raw line, channels or None).
    '''

    logger = logging.getLogger('irclog.IRCClientProtocol')

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b'' # Trailing partial line from the previous data_received call
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...}
        self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
        self.authenticated = False
        self.server = ircstates.Server(self.config['irc']['host'])

    @staticmethod
    def nick_command(nick: str):
        return b'NICK ' + nick.encode('utf-8')

    @staticmethod
    def user_command(nick: str, real: str):
        nickb = nick.encode('utf-8')
        return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')

    @staticmethod
    def valid_channel(channel: str):
        # channel[:1] instead of channel[0] so that an empty string is rejected instead of raising IndexError
        return channel[:1] in ('#', '&') and not any(x in channel for x in (' ', '\x00', '\x07', '\r', '\n', ','))

    @staticmethod
    def valid_nick(nick: str):
        # According to RFC 1459, a nick must be '<letter> { <letter> | <number> | <special> }'. This is obviously not true in practice because <special> doesn't include underscores, for example.
        # So instead, just do a sanity check similar to the channel one to disallow obvious bullshit.
        return not any(x in nick for x in (' ', '\x00', '\x07', '\r', '\n', ','))

    @staticmethod
    def prefix_to_nick(prefix: str):
        nick = prefix[1:]
        if '!' in nick:
            nick = nick.split('!', 1)[0]
        if '@' in nick: # nick@host is also legal
            nick = nick.split('@', 1)[0]
        return nick

    def connection_made(self, transport):
        self.logger.info('IRC connected')
        self.transport = transport
        self.connected = True
        if self.sasl:
            self.send(b'CAP REQ :sasl')
        self.send(self.nick_command(self.config['irc']['nick']))
        self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))

    def _send_join_part(self, command, channels):
        '''Send a JOIN or PART for the given channels, split into multiple messages as necessary.

        command: b'JOIN' or b'PART'; channels: iterable of str.

        Does nothing if channels is empty. A bare "JOIN" would be answered with 461, which message_received treats
        as a fatal error.
        '''
        channels = [x.encode('utf-8') for x in channels]
        if not channels:
            return
        if len(command) + sum(1 + len(x) for x in channels) <= 510: # Total length = command + (separator + channel name for each channel, where the separator is a space for the first and then a comma)
            # Everything fits into one command.
            self.send(command + b' ' + b','.join(channels))
            return

        # List too long, need to split.
        limit = 510 - len(command)
        sendable = []
        for channel in channels:
            if 1 + len(channel) > limit:
                # Too long to fit into a message even on its own. The config reader limits names to 200 characters,
                # so this can only happen with long multi-byte names.
                self.logger.warning(f'Cannot {command} {channel}: name too long')
            else:
                sendable.append(channel)

        # Greedily fill each message up to the limit.
        batch, batchLength = [], 0 # batchLength = sum of (separator + name) over the batch
        for channel in sendable:
            if batch and batchLength + 1 + len(channel) > limit:
                self.send(command + b' ' + b','.join(batch))
                batch, batchLength = [], 0
            batch.append(channel)
            batchLength += 1 + len(channel)
        if batch:
            self.send(command + b' ' + b','.join(batch))

    def update_channels(self, channels: set):
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                self._send_join_part(b'PART', channelsToPart)
            if channelsToJoin:
                self._send_join_part(b'JOIN', channelsToJoin)

    def send(self, data):
        self.logger.debug(f'Send: {data!r}')
        if len(data) > 510:
            raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
        time_ = time.time()
        self.transport.write(data + b'\r\n')
        self.messageQueue.put_nowait((time_, b'> ' + data, None))

    def data_received(self, data):
        time_ = time.time()
        self.logger.debug(f'Data received: {data!r}')
        # Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that.
        # Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
        # If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
        messages = data.split(b'\r\n')
        if self.buffer:
            messages[0] = self.buffer + messages[0]
        for message in messages[:-1]:
            lines = self.server.recv(message + b'\r\n')
            assert len(lines) == 1
            # Handle the line before ircstates applies it. When processing QUIT/NICK, ircstates drops/renames the
            # user (and our own nickname), so looking the user up afterwards would miss their channels.
            self.message_received(time_, message, lines[0])
            self.server.parse_tokens(lines[0])
        self.buffer = messages[-1]

    def message_received(self, time_, message, line):
        '''Queue a received line for storage and react to it where necessary.

        message is the raw line without CRLF; line is the same line as parsed by ircstates/irctokens. This is called
        before self.server has applied the line (see data_received).
        '''
        self.logger.debug(f'Message received at {time_}: {message!r}')

        # Queue message for storage
        sendGeneral = True
        if line.command in ('QUIT', 'NICK') and line.source:
            if line.hostmask.nickname == self.server.nickname:
                # Self-quit or own nick change
                sendGeneral = False
                self.messageQueue.put_nowait((time_, b'< ' + message, list(self.channels) + ['general']))
            else:
                # ircstates stores users under their casefolded nickname
                user = self.server.users.get(self.server.casefold(line.hostmask.nickname))
                if user is not None:
                    sendGeneral = False
                    # Snapshot the set: ircstates keeps mutating user.channels while Storage consumes the queue later.
                    # NOTE(review): these are channel names as tracked by ircstates (possibly casefolded) — confirm
                    # they match the configured ircchannel spelling, otherwise Storage falls back to the general log.
                    self.messageQueue.put_nowait((time_, b'< ' + message, set(user.channels)))
        if sendGeneral:
            self.messageQueue.put_nowait((time_, b'< ' + message, None))

        # PING/PONG
        if line.command == 'PING':
            self.send(irctokens.build('PONG', line.params).format().encode('utf-8'))

        # SASL
        elif line.command == 'CAP' and self.sasl:
            if len(line.params) >= 2 and line.params[-2] == 'ACK' and 'sasl' in line.params[-1].split(' '):
                self.send(b'AUTHENTICATE EXTERNAL')
            else:
                self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
                self.transport.close()
        elif line.command == 'AUTHENTICATE' and line.params == ['+']:
            self.send(b'AUTHENTICATE +')
        elif line.command == '903': # SASL auth successful
            self.authenticated = True
            self.send(b'CAP END')
        elif line.command in ('902', '904', '905', '906', '908'):
            self.logger.error('SASL error, terminating connection')
            self.transport.close()

        # NICK errors
        elif line.command in ('431', '432', '433', '436'):
            self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
            self.transport.close()

        # USER errors
        elif line.command in ('461', '462'):
            self.logger.error(f'Failed to register: {message!r}, terminating connection')
            self.transport.close()

        # JOIN errors
        elif line.command in ('405', '471', '473', '474', '475'):
            self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
            self.transport.close()

        # PART errors
        elif line.command == '442':
            self.logger.error(f'Failed to part channel: {message!r}')

        # JOIN/PART errors
        elif line.command == '403':
            self.logger.error(f'Failed to join or part channel: {message!r}')

        # Connection registration reply
        elif line.command == '001':
            self.logger.info('IRC connection registered')
            if self.sasl and not self.authenticated:
                self.logger.error('IRC connection registered but not authenticated, terminating connection')
                self.transport.close()
                return
            self._send_join_part(b'JOIN', self.channels)

        # General fatal ERROR
        elif line.command == 'ERROR':
            self.logger.error(f'Server sent ERROR: {message!r}')
            self.transport.close()

    async def quit(self):
        # The server acknowledges a QUIT by sending an ERROR and closing the connection. The latter triggers connection_lost, so just wait for the closure event.
        self.logger.info('Quitting')
        self.send(b'QUIT :Bye')
        await self.connectionClosedEvent.wait()
        self.transport.close()

    def connection_lost(self, exc):
        time_ = time.time()
        self.logger.info('IRC connection lost')
        self.connected = False
        self.connectionClosedEvent.set()
        self.messageQueue.put_nowait((time_, messageConnectionClosed, list(self.channels) + ['general']))
-
class IRCClient:
    '''Keeps an IRC connection alive until SIGINT, reconnecting 5 seconds after every failure or disconnect.'''

    logger = logging.getLogger('irclog.IRCClient')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = self._channels_from_config(config)

        self._transport = None
        self._protocol = None

    @staticmethod
    def _channels_from_config(config):
        '''Return the set of IRC channel names configured in config.'''
        return {channel['ircchannel'] for channel in config['channels'].values()}

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of awaitables completes, then cancel the others.

        asyncio.wait does not accept bare coroutines (forbidden since Python 3.11), so they are wrapped in tasks.
        The losers are cancelled so that no task is leaked per call.
        '''
        tasks = [asyncio.ensure_future(a) for a in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    def update_config(self, config):
        '''Apply a new config: reconnect if the irc section changed, otherwise join/part channels as needed.'''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        # Always track the new channel set so that the next (re)connection joins the right channels.
        self.channels = self._channels_from_config(config)
        connected = self._transport is not None and self._protocol is not None and self._protocol.connected
        if not connected:
            return
        if needReconnect:
            # Ends the current connection; run() then reconnects using the new config.
            self._transport.close()
        else:
            self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Return the value for create_connection's ssl argument according to the irc config.

        When a client certificate is configured, a fresh SSLContext is created for it. The shared contexts in
        SSL_CONTEXTS are never mutated, so a certificate removed on config reload is not presented any more.
        '''
        ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
        if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
            if ctx is True:
                ctx = ssl.create_default_context()
            elif isinstance(ctx, ssl.SSLContext):
                # 'insecure': a private context without any verification
                ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE
            if isinstance(ctx, ssl.SSLContext):
                ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
        return ctx

    async def run(self, loop, sigintEvent):
        '''Connect, wait for disconnect or SIGINT, and repeat; puts messageEOF on the queue once SIGINT is handled.'''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    if not connectionClosedEvent.is_set():
                        await self._protocol.quit()
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers ConnectionRefusedError, ssl.SSLError, DNS failures (socket.gaierror) and other network
                # errors; all of them should lead to a retry rather than take down the whole program.
                self.logger.error(f'Connection failed: {e!r}')
            await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                self.logger.debug('Got SIGINT, putting EOF and breaking')
                self.messageQueue.put_nowait(messageEOF)
                break
-
-
class Storage:
    '''Consumes messageQueue and appends every entry to the matching log file(s).

    One file is kept open per active channel, plus the general log (key None) for anything that cannot be attributed
    to an open channel. Files are <storage.path>/<channel name or "general">/<YYYY-MM>.log, where the month (UTC) is
    determined when the file is opened.
    '''

    logger = logging.getLogger('irclog.Storage')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.files = {} # channel|None -> fileobj; None = general log for anything that wasn't recognised as a message for the channel log
        self.active = True

    def update_config(self, config):
        '''Apply a (new) config: close logs of channels that were removed or deactivated, open logs of active ones.'''
        activeChannels = {channel['ircchannel'] for channel in config['channels'].values() if channel['active']}
        self.config = config

        for channel in [c for c in self.files if c is not None and c not in activeChannels]:
            self.files.pop(channel).close()

        #TODO Rotate files at the beginning of a new month; for now, the month is only determined when a file is opened.
        #TODO Reopen files when storage.path changes
        month = time.strftime('%Y-%m', time.gmtime())
        for channel in sorted(activeChannels):
            if channel not in self.files:
                self.files[channel] = self._open_log(channel, month)

        if None not in self.files:
            self.files[None] = self._open_log('general', month)

    def _open_log(self, directory, month):
        '''Open <storage.path>/<directory>/<month>.log for appending, creating the directory if necessary.'''
        dirPath = os.path.join(self.config['storage']['path'], directory)
        os.makedirs(dirPath, exist_ok = True)
        return open(os.path.join(dirPath, f'{month}.log'), 'ab')

    async def run(self, loop, sigintEvent):
        self.update_config(self.config) # Ensure that files are open etc.
        #TODO Task to rotate log files at the beginning of a new month
        storageTask = asyncio.create_task(self.store_messages(sigintEvent))
        flushTask = asyncio.create_task(self.flush_files())
        await sigintEvent.wait()
        self.logger.debug('Got SIGINT, waiting for remaining messages to be stored')
        await storageTask # Wait until everything's stored
        self.active = False
        self.logger.debug('Waiting for flush task')
        await flushTask
        self.close()

    async def store_messages(self, sigintEvent):
        '''Consume messageQueue until messageEOF, routing each entry to its log file(s).'''
        while self.active:
            self.logger.debug('Waiting for message')
            res = await self.messageQueue.get()
            self.logger.debug(f'Got {res!r} from message queue')
            if res is messageEOF:
                self.logger.debug('Message EOF, breaking store_messages loop')
                break

            time_, rawMessage, channels = res
            if rawMessage is messageConnectionClosed:
                rawMessage = b'- Connection closed'
            message = rawMessage[2:] # Remove leading > or <
            if message.startswith(b':') and b' ' in message:
                message = message.split(b' ', 1)[1] # Drop the prefix

            # Identify channel-bound messages: JOIN, PART, QUIT, MODE, KICK, PRIVMSG, NOTICE (see https://tools.ietf.org/html/rfc1459#section-4.2.1)
            if message.startswith((b'JOIN ', b'PART ', b'PRIVMSG ', b'NOTICE ')):
                # I *think* that the first parameter of JOIN/PART can only ever be a single channel for messages announcing other people joining, but who knows with how awful RFC 1459 is...
                channelsRaw = message.split(b' ', 2)[1]
                # Some servers send the target as a trailing parameter, e.g. "JOIN :#channel"
                if channelsRaw.startswith(b':'):
                    channelsRaw = channelsRaw[1:]
                channels = self.decode_channel(time_, rawMessage, channelsRaw.split(b','))
                if channels is None:
                    continue
                for channel in channels:
                    self.store_message(time_, rawMessage, channel)
                continue
            if message.startswith((b'QUIT ', b'NICK ')) or message == b'QUIT':
                # If channels is not None, IRCClientProtocol managed to track the user and identify the channels this needs to be logged to.
                # If it isn't, there might be channels in there (for some odd reason?) that are not being logged. In that case, emit one and only one message to the general log as well.
                if channels is not None:
                    for channel in channels:
                        self.store_message(time_, rawMessage, channel, redirectToGeneral = False)
                if channels is None or any(channel not in self.files for channel in channels):
                    self.store_message(time_, rawMessage, None)
                continue
            if message.startswith((b'MODE #', b'MODE &', b'KICK ')):
                channel = message.split(b' ', 2)[1]
                channel = self.decode_channel(time_, rawMessage, channel)
                if channel is None:
                    continue
                self.store_message(time_, rawMessage, channel)
                continue
            if channels is not None:
                for channel in channels:
                    self.store_message(time_, rawMessage, channel)
            else:
                self.store_message(time_, rawMessage, None)

    def store_message(self, time_, rawMessage, targetChannel, redirectToGeneral = True):
        '''Append one line to targetChannel's log; unknown channels go to the general log unless redirectToGeneral is false.'''
        self.logger.debug(f'Logging {rawMessage!r} at {time_} for {targetChannel!r}')
        if targetChannel is not None and targetChannel not in self.files:
            self.logger.debug(f'Target channel {targetChannel!r} not opened, redirecting to general log is {redirectToGeneral}')
            if not redirectToGeneral:
                return
            targetChannel = None
        self.files[targetChannel].write(str(time_).encode('ascii') + b' ' + rawMessage + b'\r\n')

    def decode_channel(self, time_, rawMessage, channel):
        '''Decode a channel name (or list of names) as UTF-8; on failure, log the raw message to the general log and return None.'''
        try:
            if isinstance(channel, list):
                return [c.decode('utf-8') for c in channel]
            return channel.decode('utf-8')
        except UnicodeDecodeError as e:
            self.logger.warning(f'Failed to decode channel name {channel!r} from {rawMessage!r} at {time_}: {e!s}')
            self.store_message(time_, rawMessage, None)
            return None

    async def flush_files(self):
        '''Flush all open log files once per second until self.active is cleared.'''
        while self.active:
            await asyncio.sleep(1)
            for f in self.files.values():
                try:
                    f.flush()
                except OSError as e:
                    self.logger.error(f'Failed to flush {f!r}: {e!s}')
        self.logger.debug('Exiting flush_files')

    def close(self):
        for f in self.files.values():
            f.close()
        self.files = {}
-
-
class WebServer:
    '''aiohttp server accepting POST requests on configured paths; rebinds when web.host/web.port change.

    NOTE(review): self._paths is never populated (the code building it is commented out in update_config), and
    self.messageQueue is never assigned in this class. As a result, every request currently ends in 404 before
    either is needed.
    '''

    logger = logging.getLogger('irclog.WebServer')

    def __init__(self, config):
        self.config = config

        self._paths = {} # '/path' => ('#channel', auth, module, moduleargs, overlongmode) where auth is either False (no authentication) or the HTTP header value for basic auth

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        # Must exist before update_config, which sets it when a rebind is needed.
        self._configChanged = asyncio.Event()
        self.update_config(config)

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of awaitables completes, then cancel the others.

        asyncio.wait does not accept bare coroutines (forbidden since Python 3.11), so they are wrapped in tasks.
        The losers are cancelled so that no task is leaked per call.
        '''
        tasks = [asyncio.ensure_future(a) for a in awaitables]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()

    def update_config(self, config):
        #TODO Populate self._paths from the channel config, e.g.:
        # self._paths = {channel['webpath']: (channel['ircchannel'], f'Basic {base64.b64encode(channel["auth"].encode("utf-8")).decode("utf-8")}' if channel['auth'] else False) for channel in config['channels'].values()}
        needRebind = self.config['web'] != config['web'] #TODO only if there are changes to web.host or web.port; everything else can be updated without rebinding
        self.config = config
        if needRebind:
            self._configChanged.set()

    async def run(self, stopEvent):
        '''Serve until stopEvent is set, restarting the site whenever update_config signals a rebind.'''
        while True:
            runner = aiohttp.web.AppRunner(self._app)
            await runner.setup()
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                await self._wait_first(stopEvent.wait(), self._configChanged.wait())
            finally:
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process
    # https://stackoverflow.com/questions/1180606/using-subprocess-popen-for-process-with-large-output
    # -> https://stackoverflow.com/questions/57730010/python-asyncio-subprocess-write-stdin-and-read-stdout-stderr-continuously

    async def post(self, request):
        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.path!r} with body {(await request.read())!r}')
        try:
            channel, auth, module, moduleargs, overlongmode = self._paths[request.path]
        except KeyError:
            self.logger.info(f'Bad request {id(request)}: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            if authHeader != auth:
                # Log neither the expected value (the credential itself) nor the submitted one.
                self.logger.info(f'Bad request {id(request)}: authentication failed')
                raise aiohttp.web.HTTPForbidden()
        if module is not None:
            self.logger.debug(f'Processing request {id(request)} using {module!r}')
            try:
                message = await module.process(request, *moduleargs)
            except aiohttp.web.HTTPException:
                raise
            except Exception as e:
                self.logger.error(f'Bad request {id(request)}: exception in module process function: {type(e).__module__}.{type(e).__name__}: {e!s}')
                raise aiohttp.web.HTTPBadRequest()
            if '\r' in message or '\n' in message:
                self.logger.error(f'Bad request {id(request)}: module process function returned message with linebreaks: {message!r}')
                raise aiohttp.web.HTTPBadRequest()
        else:
            self.logger.debug(f'Processing request {id(request)} using default processor')
            message = await self._default_process(request)
        self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message, overlongmode))
        raise aiohttp.web.HTTPOk()

    async def _default_process(self, request):
        '''Return the request body as text, minus one optional trailing [CR]LF; reject bodies with other linebreaks.'''
        try:
            message = await request.text()
        except Exception as e:
            self.logger.info(f'Bad request {id(request)}: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        self.logger.debug(f'Request {id(request)} payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            self.logger.info(f'Bad request {id(request)}: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        return message
-
-
def configure_logging(config):
    '''(Re)configure the root logger from config['logging'].

    Sets the root level and replaces all existing root handlers with a single stderr handler that uses the
    configured '{'-style format.
    '''
    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    root = logging.getLogger()
    root.setLevel(getattr(logging, config['logging']['level']))
    # Remove handlers through the public API rather than assigning to Logger.handlers directly.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    formatter = logging.Formatter(config['logging']['format'], style = '{')
    stderrHandler = logging.StreamHandler()
    stderrHandler.setFormatter(formatter)
    root.addHandler(stderrHandler)
-
-
async def main():
    '''Run the logger: load the config named on the command line, then run IRC client, web server and storage.

    Exits with status 1 and a usage message unless exactly one argument (the config file) is given.
    SIGINT stops all components. SIGUSR1 re-reads the config file and hands it to every component; if the new
    config is invalid, the old one stays active.
    '''
    if len(sys.argv) == 2:
        configFile = sys.argv[1]
    else:
        print('Usage: irclog.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    config = Config(configFile)
    configure_logging(config)

    loop = asyncio.get_running_loop()

    # Queue items are tuple(time: float, message: bytes or messageConnectionClosed, channels), or the messageEOF sentinel.
    # channels is None when IRCClientProtocol did not identify the affected channels; otherwise it is a set or list of
    # channel names (QUIT and NICK messages, and the connection-closed marker).
    messageQueue = asyncio.Queue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(config)
    storage = Storage(messageQueue, config)

    sigintEvent = asyncio.Event()

    def sigint_callback():
        logger.info('Got SIGINT, stopping')
        sigintEvent.set()

    def sigusr1_callback():
        nonlocal config
        logger.info('Got SIGUSR1, reloading config')
        try:
            newConfig = config.reread()
        except InvalidConfig as e:
            logger.error(f'Config reload failed: {e!s} (old config remains active)')
            return
        config = newConfig
        configure_logging(config)
        for component in (irc, webserver, storage):
            component.update_config(config)

    loop.add_signal_handler(signal.SIGINT, sigint_callback)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(
        irc.run(loop, sigintEvent),
        webserver.run(sigintEvent),
        storage.run(loop, sigintEvent),
    )


if __name__ == '__main__':
    asyncio.run(main())
|