|
|
@@ -0,0 +1,673 @@ |
|
|
|
import asyncio
import base64
import collections
import collections.abc
import concurrent.futures
import hmac
import importlib.util
import inspect
import itertools
import logging
import os.path
import signal
import ssl
import string
import sys
import tempfile
import time

import aiohttp
import aiohttp.web
import toml
|
|
|
|
|
|
|
|
|
|
|
# Logger for module-level code (signal handlers in main etc.); the classes below use their own child loggers.
logger = logging.getLogger('irclog')


def _make_insecure_ssl_context():
    '''Build a TLS client context that verifies neither the certificate nor the hostname'''
    # ssl.SSLContext() without a protocol argument is deprecated since Python 3.10, so name the protocol explicitly.
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    # PROTOCOL_TLS_CLIENT enables verification by default; check_hostname must be turned off before verify_mode can be lowered.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx


# Maps the irc.ssl config value to the `ssl` argument of loop.create_connection:
# True = default verified TLS, False = plain text, 'insecure' = TLS without any verification.
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _make_insecure_ssl_context()}
|
|
|
|
|
|
|
|
|
|
|
class InvalidConfig(Exception):
    '''Raised when the configuration file has unknown keys, wrongly typed values, or otherwise unusable settings'''
|
|
|
|
|
|
|
|
|
|
|
def is_valid_pem(path, withCert):
    '''Very basic check whether something looks like a valid PEM certificate and/or private key

    path: path of the file to inspect
    withCert: if true, the file must consist of one certificate immediately followed by one private key; otherwise, it must consist of exactly one private key

    Returns True if the file has the expected structure and the payloads are valid base64, False otherwise (including when the file cannot be read).
    Only the framing is checked; the decoded data is not parsed.
    '''

    certBegin = b'-----BEGIN CERTIFICATE-----\n'
    certEnd = b'-----END CERTIFICATE-----\n'
    keyBegin = b'-----BEGIN PRIVATE KEY-----\n'
    keyEnd = b'-----END PRIVATE KEY-----\n'

    # Explicit checks instead of assert statements: asserts are stripped under `python -O`, which would make every file pass.
    # Only the expected failure types are caught so that KeyboardInterrupt/SystemExit are not swallowed.
    try:
        with open(path, 'rb') as fp:
            contents = fp.read()

        if withCert:
            if not contents.startswith(certBegin):
                return False
            endCertPos = contents.index(certEnd) # ValueError if the end marker is missing
            base64.b64decode(contents[len(certBegin):endCertPos].replace(b'\n', b''), validate = True) # binascii.Error is a ValueError
            contents = contents[endCertPos + len(certEnd):] # The key has to follow directly

        if not contents.startswith(keyBegin):
            return False
        endKeyPos = contents.index(keyEnd)
        base64.b64decode(contents[len(keyBegin):endKeyPos].replace(b'\n', b''), validate = True)
    except (OSError, ValueError):
        return False

    # Nothing may follow the key
    return contents[endKeyPos + len(keyEnd):] == b''
|
|
|
|
|
|
|
|
|
|
|
class Config(dict):
    '''Validated configuration read from a TOML file

    After construction, the instance holds the keys 'logging', 'storage', 'irc', 'web', and 'channels', each mapping to the corresponding section with default values filled in.
    Relative paths in the file (storage.path, irc.certfile, irc.certkeyfile) are resolved against the directory of the config file.

    Raises InvalidConfig if the file's contents fail validation. Errors from opening or parsing the file are propagated unchanged.
    '''

    _SECTIONS = ('logging', 'storage', 'irc', 'web', 'channels')

    def __init__(self, filename):
        super().__init__()
        self._filename = filename

        with open(self._filename, 'r', encoding = 'utf-8') as fp: # TOML files are always UTF-8
            obj = toml.load(fp)

        # Sanity checks
        if any(x not in self._SECTIONS for x in obj):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')

        # Default values
        finalObj = {'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, 'storage': {'path': os.path.abspath(os.path.dirname(self._filename))}, 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'irclogbot', 'real': 'I am an irclog bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'channels': {}}
        # Default values for channels are filled in by _check_channels.

        if 'logging' in obj:
            self._check_logging(obj['logging'])
        if 'storage' in obj:
            self._check_storage(obj['storage'])
        if 'irc' in obj:
            self._check_irc(obj['irc'], finalObj['irc'])
        if 'web' in obj:
            self._check_web(obj['web'])
        if 'channels' in obj:
            self._check_channels(obj['channels'])

        # Merge in what was read from the config file and set keys on self
        for key in self._SECTIONS:
            if key in obj:
                finalObj[key].update(obj[key])
            self[key] = finalObj[key]

    def _resolve_path(self, path):
        '''Return path as an absolute path, interpreting relative paths relative to the config file's directory'''
        return os.path.abspath(os.path.join(os.path.dirname(self._filename), path))

    @staticmethod
    def _is_valid_port(port):
        '''Check whether port is an integer in the valid TCP port range (bool is rejected although it is an int subclass)'''
        return isinstance(port, int) and not isinstance(port, bool) and 1 <= port <= 65535

    @staticmethod
    def _check_logging(section):
        '''Validate the logging section'''
        if any(x not in ('level', 'format') for x in section):
            raise InvalidConfig('Unknown key found in log section')
        if 'level' in section and section['level'] not in ('DEBUG', 'INFO', 'WARNING', 'ERROR'):
            raise InvalidConfig('Invalid log level')
        if 'format' in section:
            if not isinstance(section['format'], str):
                raise InvalidConfig('Invalid log format')
            try:
                #TODO: Replace with logging.Formatter's validate option (3.8+); this test does not cover everything that could be wrong (e.g. invalid format spec or conversion)
                # This counts the number of replacement fields. Formatter.parse yields tuples whose second value is the field name; if it's None, there is no field (e.g. literal text).
                fieldCount = sum(1 for x in string.Formatter().parse(section['format']) if x[1] is not None)
            except ValueError as e:
                raise InvalidConfig('Invalid log format: parsing failed') from e
            if fieldCount == 0:
                # Explicit check rather than an assert so that it still works under `python -O`
                raise InvalidConfig('Invalid log format: parsing failed')

    def _check_storage(self, section):
        '''Validate the storage section and make its path absolute'''
        if any(x != 'path' for x in section):
            raise InvalidConfig('Unknown key found in storage section')
        if 'path' in section:
            if not isinstance(section['path'], str):
                raise InvalidConfig('Invalid storage path: not a string')
            section['path'] = self._resolve_path(section['path'])
            try:
                # Creating (and immediately discarding) a temporary file proves that the directory exists and is writable.
                with tempfile.TemporaryFile(dir = section['path']):
                    pass
            except OSError as e:
                raise InvalidConfig('Invalid storage path: not writable') from e

    def _check_irc(self, section, defaults):
        '''Validate the irc section and make the certificate paths absolute; defaults supplies the nick/realname used when the file omits them'''
        if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in section):
            raise InvalidConfig('Unknown key found in irc section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname
            raise InvalidConfig('Invalid IRC host')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid IRC port')
        if 'ssl' in section and section['ssl'] not in ('yes', 'no', 'insecure'):
            raise InvalidConfig(f'Invalid IRC SSL setting: {section["ssl"]!r}')

        # The length checks have to use the effective values, i.e. fall back to the defaults for keys missing from the file.
        nick = section.get('nick', defaults['nick'])
        real = section.get('real', defaults['real'])
        if not isinstance(nick, str): #TODO: Check whether it's a valid nickname, username, etc.
            raise InvalidConfig('Invalid IRC nick')
        if len(IRCClientProtocol.nick_command(nick)) > 510:
            raise InvalidConfig('Invalid IRC nick: NICK command too long')
        if not isinstance(real, str):
            raise InvalidConfig('Invalid IRC realname')
        if len(IRCClientProtocol.user_command(nick, real)) > 510:
            raise InvalidConfig('Invalid IRC nick/realname combination: USER command too long')

        if ('certfile' in section) != ('certkeyfile' in section):
            raise InvalidConfig('Invalid IRC cert config: needs both certfile and certkeyfile')
        if 'certfile' in section:
            if not isinstance(section['certfile'], str):
                raise InvalidConfig('Invalid certificate file: not a string')
            section['certfile'] = self._resolve_path(section['certfile'])
            if not os.path.isfile(section['certfile']):
                raise InvalidConfig('Invalid certificate file: not a regular file')
            if not is_valid_pem(section['certfile'], True):
                raise InvalidConfig('Invalid certificate file: not a valid PEM cert')
        if 'certkeyfile' in section:
            if not isinstance(section['certkeyfile'], str):
                raise InvalidConfig('Invalid certificate key file: not a string')
            section['certkeyfile'] = self._resolve_path(section['certkeyfile'])
            if not os.path.isfile(section['certkeyfile']):
                raise InvalidConfig('Invalid certificate key file: not a regular file')
            if not is_valid_pem(section['certkeyfile'], False):
                raise InvalidConfig('Invalid certificate key file: not a valid PEM key')

    def _check_web(self, section):
        '''Validate the web section'''
        if any(x not in ('host', 'port') for x in section):
            raise InvalidConfig('Unknown key found in web section')
        if 'host' in section and not isinstance(section['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
            raise InvalidConfig('Invalid web hostname')
        if 'port' in section and not self._is_valid_port(section['port']):
            raise InvalidConfig('Invalid web port')

    @staticmethod
    def _check_channels(section):
        '''Validate every channel entry and fill in the per-channel defaults (ircchannel = '#' + key, auth = False, active = True)'''
        seenChannels = {} # IRC channel name -> config key that claimed it
        for key, channel in section.items():
            if not isinstance(key, str) or not key:
                raise InvalidConfig(f'Invalid channel key {key!r}')
            if not isinstance(channel, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid channel for {key!r}')
            if any(x not in ('ircchannel', 'auth', 'active') for x in channel):
                raise InvalidConfig(f'Unknown key(s) found in channel {key!r}')

            if 'ircchannel' not in channel:
                channel['ircchannel'] = f'#{key}'
            if not isinstance(channel['ircchannel'], str):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: not a string')
            if not channel['ircchannel'].startswith(('#', '&')):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: does not start with # or &')
            if any(x in channel['ircchannel'][1:] for x in (' ', '\x00', '\x07', '\r', '\n', ',')):
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: contains forbidden characters')
            if len(channel['ircchannel']) > 200:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: too long')
            if channel['ircchannel'] in seenChannels:
                raise InvalidConfig(f'Invalid channel {key!r} IRC channel: collides with channel {seenChannels[channel["ircchannel"]]!r}')
            seenChannels[channel['ircchannel']] = key

            if 'auth' in channel:
                if channel['auth'] is not False and not isinstance(channel['auth'], str):
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must be false or a string')
                if isinstance(channel['auth'], str) and ':' not in channel['auth']:
                    raise InvalidConfig(f'Invalid channel {key!r} auth: must contain a colon')
            else:
                channel['auth'] = False

            if 'active' in channel:
                if channel['active'] is not True and channel['active'] is not False:
                    raise InvalidConfig(f'Invalid channel {key!r} active: must be true or false')
            else:
                channel['active'] = True

    def __repr__(self):
        return f'<Config(logging={self["logging"]!r}, storage={self["storage"]!r}, irc={self["irc"]!r}, web={self["web"]!r}, channels={self["channels"]!r})>'

    def reread(self):
        '''Read the same file again and return the result as a new Config; this instance is left untouched'''
        return Config(self._filename)
|
|
|
|
|
|
|
|
|
|
|
class IRCClientProtocol(asyncio.Protocol):
    '''IRC client connection: registers (optionally with SASL EXTERNAL), joins the configured channels, answers PINGs, and queues every sent and received line for storage

    Lines are put on messageQueue as (timestamp, b'> ' + line) for sent and (timestamp, b'< ' + line) for received messages.
    connectionClosedEvent is set when the connection is lost.
    '''

    logger = logging.getLogger('irclog.IRCClientProtocol')

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b'' # Received data that has not been terminated by CRLF yet
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
        self.authenticated = False
        self.usermask = None

    @staticmethod
    def nick_command(nick: str):
        '''Build the NICK command (without CRLF) for nick'''
        return b'NICK ' + nick.encode('utf-8')

    @staticmethod
    def user_command(nick: str, real: str):
        '''Build the USER command (without CRLF); nick is used for the user, mode, and unused parameters alike'''
        nickb = nick.encode('utf-8')
        return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')

    def _maybe_set_usermask(self, usermask):
        '''Store usermask if it has the shape nick!user@host and contains no spaces, wildcards, or channel prefixes'''
        if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
            self.usermask = usermask
            self.logger.debug(f'Usermask is now {usermask!r}')

    def connection_made(self, transport):
        self.logger.info('IRC connected')
        self.transport = transport
        self.connected = True
        if self.sasl:
            self.send(b'CAP REQ :sasl')
        self.send(self.nick_command(self.config['irc']['nick']))
        self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))

    def _send_join_part(self, command, channels):
        '''Split a JOIN or PART into multiple messages as necessary'''
        # command: b'JOIN' or b'PART'; channels: set[str]

        # Each channel costs its own length plus one separator (a space for the first one in a message, a comma for the others).
        # Channels are packed greedily; nothing at all is sent for an empty set because a parameterless JOIN/PART is answered
        # with 461, which message_received treats as a fatal registration error.
        batch = []
        batchLength = len(command)
        for channel in (x.encode('utf-8') for x in channels):
            if len(command) + 1 + len(channel) > 510:
                # Too long to even fit into one message on its own.
                # This should never happen since the config reader would already filter it out.
                self.logger.warning(f'Cannot {command} {channel}: name too long')
                continue
            if batch and batchLength + 1 + len(channel) > 510:
                self.send(command + b' ' + b','.join(batch))
                batch = []
                batchLength = len(command)
            batch.append(channel)
            batchLength += 1 + len(channel)
        if batch:
            self.send(command + b' ' + b','.join(batch))

    def update_channels(self, channels: set):
        '''Replace the set of channels; if connected, part the ones that were removed and join the ones that were added'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                self._send_join_part(b'PART', channelsToPart)
            if channelsToJoin:
                self._send_join_part(b'JOIN', channelsToJoin)

    def send(self, data):
        '''Send one IRC message (data without the trailing CRLF) and queue it for storage; raises RuntimeError if it exceeds 510 bytes'''
        self.logger.debug(f'Send: {data!r}')
        if len(data) > 510:
            raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
        time_ = time.time()
        self.transport.write(data + b'\r\n')
        self.messageQueue.put_nowait((time_, b'> ' + data))

    def data_received(self, data):
        self.logger.debug(f'Data received: {data!r}')
        time_ = time.time()
        # Prepend whatever was left over from the previous call, then split on CRLF.
        # Everything but the last piece is a complete message; the last piece is the (possibly empty) unterminated remainder
        # and stays in the buffer until more data arrives. Joining before splitting also handles a CRLF that is torn across two reads.
        *messages, self.buffer = (self.buffer + data).split(b'\r\n')
        for message in messages:
            self.message_received(time_, message)

    def message_received(self, time_, message):
        '''Queue one complete received line for storage and react to the commands/numerics relevant to registration, SASL, and joining'''
        self.logger.debug(f'Message received at {time_}: {message!r}')
        rawMessage = message
        if message.startswith(b':') and b' ' in message:
            # Prefixed message, extract command + parameters (the prefix cannot contain a space)
            message = message.split(b' ', 1)[1]

        # Queue message for storage
        self.messageQueue.put_nowait((time_, b'< ' + rawMessage))

        # PING/PONG
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

        # SASL
        elif message.startswith(b'CAP ') and self.sasl:
            # Skip the client identifier parameter after CAP and look at the subcommand
            if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
                self.send(b'AUTHENTICATE EXTERNAL')
            else:
                self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
                self.transport.close()
        elif message == b'AUTHENTICATE +':
            self.send(b'AUTHENTICATE +')
        elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
            words = message.split(b' ')
            if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
                if b'!~' not in words[2]:
                    # At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
                    words[2] = words[2].replace(b'!', b'!~', 1)
                self._maybe_set_usermask(words[2])
        elif message.startswith(b'903 '): # SASL auth successful
            self.authenticated = True
            self.send(b'CAP END')
        elif message.startswith((b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
            self.logger.error('SASL error, terminating connection')
            self.transport.close()

        # NICK errors
        elif message.startswith((b'431 ', b'432 ', b'433 ', b'436 ')):
            self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
            self.transport.close()

        # USER errors
        elif message.startswith((b'461 ', b'462 ')):
            self.logger.error(f'Failed to register: {message!r}, terminating connection')
            self.transport.close()

        # JOIN errors
        elif message.startswith((b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
            self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
            self.transport.close()

        # PART errors
        elif message.startswith(b'442 '):
            self.logger.error(f'Failed to part channel: {message!r}')

        # JOIN/PART errors
        elif message.startswith(b'403 '):
            self.logger.error(f'Failed to join or part channel: {message!r}')

        # PRIVMSG errors
        elif message.startswith((b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ')):
            self.logger.error(f'Failed to send message: {message!r}')

        # Connection registration reply
        elif message.startswith(b'001 '):
            self.logger.info('IRC connection registered')
            if self.sasl and not self.authenticated:
                self.logger.error('IRC connection registered but not authenticated, terminating connection')
                self.transport.close()
                return
            self._send_join_part(b'JOIN', self.channels)

        # JOIN success
        elif message.startswith(b'JOIN ') and not self.usermask:
            # If this is my own join message, it should contain the usermask in the prefix
            if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
                usermask = rawMessage.split(b' ', 1)[0][1:]
                self._maybe_set_usermask(usermask)

        # Services host change
        elif message.startswith(b'396 '):
            words = message.split(b' ')
            if len(words) >= 3:
                # Sanity check inspired by irssi src/irc/core/irc-servers.c
                if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
                    if b'@' in words[2]: # user@host
                        self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2])
                    else: # host (get user from previous mask or settings)
                        if self.usermask:
                            user = self.usermask.split(b'@')[0].split(b'!')[1]
                        else:
                            user = b'~' + self.config['irc']['nick'].encode('utf-8')
                        self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])

    def connection_lost(self, exc):
        self.logger.info('IRC connection lost')
        self.connected = False
        self.connectionClosedEvent.set()
|
|
|
|
|
|
|
|
|
|
|
class IRCClient:
    '''Keeps an IRC connection (IRCClientProtocol) alive, reconnecting as needed, and applies config changes to it'''

    logger = logging.getLogger('irclog.IRCClient')

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = self._channels_from_config(config)

        # Both are None while no connection exists
        self._transport = None
        self._protocol = None

    @staticmethod
    def _channels_from_config(config):
        '''Return the set of IRC channel names listed in config'''
        return {channel['ircchannel'] for channel in config['channels'].values()}

    @staticmethod
    async def _wait_first(*awaitables):
        '''Wait until the first of the awaitables finishes and cancel the others'''
        # asyncio.wait no longer accepts bare coroutines as of Python 3.11, so wrap them in tasks explicitly.
        tasks = [asyncio.ensure_future(x) for x in awaitables]
        try:
            await asyncio.wait(tasks, return_when = concurrent.futures.FIRST_COMPLETED)
        finally:
            # Don't leave the losing waiter(s) pending forever
            for task in tasks:
                if not task.done():
                    task.cancel()

    def update_config(self, config):
        '''Switch to a new config: reconnect if the irc section changed, otherwise only adjust the joined channels'''
        needReconnect = self.config['irc'] != config['irc']
        self.config = config
        # Always remember the new channel list, also when disconnected or about to reconnect, so that the next connection joins the right channels.
        self.channels = self._channels_from_config(config)
        if self._transport: # if currently connected:
            if needReconnect:
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    def _get_ssl_context(self):
        '''Return the `ssl` argument for create_connection according to irc.ssl, with the client certificate loaded if one is configured'''
        ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
        if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
            if ctx is True:
                ctx = ssl.create_default_context()
            if isinstance(ctx, ssl.SSLContext):
                ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
        return ctx

    async def run(self, loop, sigintEvent):
        '''Connect and stay connected until sigintEvent is set; failed connection attempts are retried after 5 seconds'''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
                    self._transport = None
                    self._protocol = None
            except (OSError, asyncio.TimeoutError) as e:
                # OSError covers refused connections as well as DNS and TLS failures; all of them are worth retrying rather than ending the task.
                self.logger.error(str(e))
                await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break
|
|
|
|
|
|
|
|
|
|
|
class Storage:
    '''Writes the (timestamp, line) tuples from messageQueue to per-channel log files below storage.path

    Lines that cannot be attributed to an active, configured channel go to the 'general' log.
    '''

    logger = logging.getLogger('irclog.Storage')

    _LOG_FILENAME = '2020-10.log' #TODO Month
    _FLUSH_INTERVAL = 1 # Seconds between flushes of the log files

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.files = {} # channel|None -> fileobj; None = general log for anything that wasn't recognised as a message for the channel log
        self.active = True

    def _open_log(self, dirname):
        '''Open the log file in the directory dirname below the storage path for appending, creating the directory if necessary'''
        # NOTE(review): dirname is the IRC channel name as configured; the config validation does not forbid path separators in it.
        directory = os.path.join(self.config['storage']['path'], dirname)
        os.makedirs(directory, exist_ok = True)
        return open(os.path.join(directory, self._LOG_FILENAME), 'ab')

    def update_config(self, config):
        '''Switch to a new config: close the logs of channels that were removed or deactivated and open the ones that are missing'''
        self.config = config
        wanted = {channel['ircchannel'] for channel in config['channels'].values() if channel['active']}

        for name in [x for x in self.files if x is not None and x not in wanted]:
            self.files.pop(name).close()

        for name in wanted:
            if name not in self.files:
                self.files[name] = self._open_log(name)

        if None not in self.files:
            self.files[None] = self._open_log('general')

    async def run(self, loop, sigintEvent):
        '''Store messages until sigintEvent is set, then write out what is still queued and close all files'''
        self.update_config(self.config) # Ensure that files are open etc.
        #TODO Task to rotate log files at the beginning of a new month
        storageTask = asyncio.create_task(self.store_messages(sigintEvent))
        flushTask = asyncio.create_task(self.flush_files(sigintEvent))
        try:
            await sigintEvent.wait()
        finally:
            self.active = False
            # The storage task is normally blocked on the queue, so it has to be cancelled rather than waited for.
            storageTask.cancel()
            flushTask.cancel()
            results = await asyncio.gather(storageTask, flushTask, return_exceptions = True)
            for result in results:
                if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                    self.logger.error(f'Storage task failed: {type(result).__module__}.{type(result).__name__}: {result!s}')
            # Messages that were queued but not picked up yet must not get lost.
            while True:
                try:
                    time_, rawMessage = self.messageQueue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                self._handle_message(time_, rawMessage)
            self.close()

    async def store_messages(self, sigintEvent):
        '''Take messages off the queue and store them until the storage is deactivated or the task is cancelled'''
        while self.active:
            time_, rawMessage = await self.messageQueue.get()
            self._handle_message(time_, rawMessage)

    def _handle_message(self, time_, rawMessage):
        '''Determine which log(s) a queued line belongs to and write it there'''
        message = rawMessage[2:] # Remove leading > or <
        if message.startswith(b':') and b' ' in message:
            message = message.split(b' ', 1)[1] # Drop the prefix

        # Identify channel-bound messages: JOIN, PART, QUIT, MODE, KICK, PRIVMSG, NOTICE (see https://tools.ietf.org/html/rfc1459#section-4.2.1)
        #TODO QUIT: Need to keep track of users to figure out in which channels they were... Ugh
        if message.startswith((b'JOIN ', b'PART ', b'PRIVMSG ', b'NOTICE ')):
            # I *think* that the first parameter of JOIN/PART can only ever be a single channel for messages announcing other people joining, but who knows with how awful RFC 1459 is...
            target = message.split(b' ', 2)[1]
            if target.startswith(b':'):
                # The parameter may be sent in trailing form, e.g. 'JOIN :#channel'
                target = target[1:]
            rawChannels = target.split(b',')
        elif message.startswith((b'MODE #', b'MODE &', b'KICK ')):
            rawChannels = [message.split(b' ', 2)[1]]
        else:
            self.store_message(time_, rawMessage, None)
            return

        channels = self.decode_channel(time_, rawMessage, rawChannels)
        if channels is None: # Undecodable; decode_channel already stored the line in the general log
            return
        # Map unknown channels to the general log first so that the line is written only once per file.
        targets = []
        for channel in channels:
            if channel not in self.files:
                channel = None
            if channel not in targets:
                targets.append(channel)
        for channel in targets:
            self.store_message(time_, rawMessage, channel)

    def store_message(self, time_, rawMessage, targetChannel):
        '''Append the line to the log of targetChannel, or to the general log if there is no open log for that channel'''
        if targetChannel is not None and targetChannel not in self.files:
            targetChannel = None
        self.files[targetChannel].write(str(time_).encode('ascii') + b' ' + rawMessage + b'\r\n')

    def decode_channel(self, time_, rawMessage, channel):
        '''Decode a channel name (bytes) or a list of them as UTF-8; on failure, log a warning, store the line in the general log, and return None'''
        try:
            if isinstance(channel, list):
                return [c.decode('utf-8') for c in channel]
            return channel.decode('utf-8')
        except UnicodeDecodeError as e:
            self.logger.warning(f'Failed to decode channel name {channel!r} from {rawMessage!r} at {time_}: {e!s}')
            self.store_message(time_, rawMessage, None)
            return None

    async def flush_files(self, sigintEvent):
        '''Flush all open log files periodically until the storage is deactivated or sigintEvent is set'''
        while self.active:
            try:
                await asyncio.wait_for(sigintEvent.wait(), timeout = self._FLUSH_INTERVAL)
            except asyncio.TimeoutError:
                pass
            for f in list(self.files.values()):
                f.flush()
            if sigintEvent.is_set():
                break

    def close(self):
        '''Close all log files'''
        for f in self.files.values():
            f.close()
        self.files = {}
|
|
|
|
|
|
|
|
|
|
|
class WebServer:
    '''aiohttp-based HTTP server accepting POST requests on the paths registered in _paths

    messageQueue is optional for backward compatibility; it receives a (channel, message, overlongmode) tuple for every accepted request.
    '''

    logger = logging.getLogger('irclog.WebServer')

    def __init__(self, config, messageQueue = None):
        self.config = config
        self.messageQueue = messageQueue

        self._paths = {} # '/path' => ('#channel', auth, module, moduleargs, overlongmode) where auth is either False (no authentication) or the HTTP header value for basic auth
        # Has to exist before update_config is called since that may set it
        self._configChanged = asyncio.Event()

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        self.update_config(config)

    def update_config(self, config):
        '''Switch to a new config; if the web section changed, signal run() to rebind the listening socket'''
#       self._paths = {channel['webpath']: (channel['ircchannel'], f'Basic {base64.b64encode(channel["auth"].encode("utf-8")).decode("utf-8")}' if channel['auth'] else False) for channel in config['channels'].values()}
        needRebind = self.config['web'] != config['web'] #TODO only if there are changes to web.host or web.port; everything else can be updated without rebinding
        self.config = config
        if needRebind:
            self._configChanged.set()

    async def run(self, stopEvent):
        '''Serve on web.host:web.port until stopEvent is set, rebinding whenever the web config changes'''
        while True:
            runner = aiohttp.web.AppRunner(self._app)
            await runner.setup()
            try:
                site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
                await site.start()
                # asyncio.wait no longer accepts bare coroutines as of Python 3.11, so wrap them in tasks explicitly.
                waiters = [asyncio.ensure_future(stopEvent.wait()), asyncio.ensure_future(self._configChanged.wait())]
                try:
                    await asyncio.wait(waiters, return_when = concurrent.futures.FIRST_COMPLETED)
                finally:
                    for waiter in waiters:
                        if not waiter.done():
                            waiter.cancel()
            finally:
                await runner.cleanup()
            if stopEvent.is_set():
                break
            self._configChanged.clear()

    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process
    # https://stackoverflow.com/questions/1180606/using-subprocess-popen-for-process-with-large-output
    # -> https://stackoverflow.com/questions/57730010/python-asyncio-subprocess-write-stdin-and-read-stdout-stderr-continuously

    async def post(self, request):
        '''Handle a POST request: look up the path, check the authentication, turn the body into a message, and queue it

        Raises HTTPNotFound for unknown paths, HTTPForbidden for failed authentication, HTTPBadRequest for unusable payloads, HTTPInternalServerError if no message queue is configured, and HTTPOk on success.
        '''
        self.logger.info(f'Received request {id(request)} from {request.remote!r} for {request.path!r} with body {(await request.read())!r}')
        try:
            channel, auth, module, moduleargs, overlongmode = self._paths[request.path]
        except KeyError:
            self.logger.info(f'Bad request {id(request)}: no path {request.path!r}')
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            # Constant-time comparison on bytes (compare_digest rejects non-ASCII str); the expected value is a secret and therefore not logged.
            if not authHeader or not hmac.compare_digest(authHeader.encode('utf-8'), auth.encode('utf-8')):
                self.logger.info(f'Bad request {id(request)}: authentication failed')
                raise aiohttp.web.HTTPForbidden()
        if module is not None:
            self.logger.debug(f'Processing request {id(request)} using {module!r}')
            try:
                message = await module.process(request, *moduleargs)
            except aiohttp.web.HTTPException:
                # Deliberate HTTP responses from the module are passed through
                raise
            except Exception as e:
                self.logger.error(f'Bad request {id(request)}: exception in module process function: {type(e).__module__}.{type(e).__name__}: {e!s}')
                raise aiohttp.web.HTTPBadRequest()
            if '\r' in message or '\n' in message:
                self.logger.error(f'Bad request {id(request)}: module process function returned message with linebreaks: {message!r}')
                raise aiohttp.web.HTTPBadRequest()
        else:
            self.logger.debug(f'Processing request {id(request)} using default processor')
            message = await self._default_process(request)
        if self.messageQueue is None:
            self.logger.error(f'Cannot accept request {id(request)}: no message queue configured')
            raise aiohttp.web.HTTPInternalServerError()
        self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, message, overlongmode))
        raise aiohttp.web.HTTPOk()

    async def _default_process(self, request):
        '''Return the request body as a single-line message; raises HTTPBadRequest if it cannot be read or contains linebreaks'''
        try:
            message = await request.text()
        except Exception as e:
            self.logger.info(f'Bad request {id(request)}: exception while reading request data: {e!s}')
            raise aiohttp.web.HTTPBadRequest() # Yes, it's always the client's fault. :-)
        self.logger.debug(f'Request {id(request)} payload: {message!r}')
        # Strip optional [CR] LF at the end of the payload
        if message.endswith('\r\n'):
            message = message[:-2]
        elif message.endswith('\n'):
            message = message[:-1]
        if '\r' in message or '\n' in message:
            self.logger.info(f'Bad request {id(request)}: linebreaks in message')
            raise aiohttp.web.HTTPBadRequest()
        return message
|
|
|
|
|
|
|
|
|
|
|
def configure_logging(config):
    '''(Re)configure the root logger: apply the level from the logging section and replace all handlers with a single stderr handler using the configured '{'-style format'''

    #TODO: Replace with logging.basicConfig(..., force = True) (Py 3.8+)
    levelName = config['logging']['level']
    formatString = config['logging']['format']

    # Build the new handler first...
    handler = logging.StreamHandler() # Writes to stderr by default
    handler.setFormatter(logging.Formatter(formatString, style = '{'))

    # ... then swap it in on the root logger.
    rootLogger = logging.getLogger()
    rootLogger.setLevel(getattr(logging, levelName))
    rootLogger.handlers = [] #FIXME: Undocumented attribute of logging.Logger
    rootLogger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
|
async def main():
    '''Entry point: load the config named on the command line, then run the IRC client, web server, and storage until SIGINT

    SIGUSR1 triggers a config reload; if the reload fails, the old config remains active.
    '''

    if len(sys.argv) != 2:
        print('Usage: irclog.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    configFile = sys.argv[1]
    config = Config(configFile)
    configure_logging(config)

    loop = asyncio.get_running_loop()

    messageQueue = asyncio.Queue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(config)
    storage = Storage(messageQueue, config)

    sigintEvent = asyncio.Event()
    def sigint_callback():
        logger.info('Got SIGINT, stopping')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    def sigusr1_callback():
        nonlocal config # Rebound below; everything else is only read
        logger.info('Got SIGUSR1, reloading config')
        try:
            newConfig = config.reread()
        except (InvalidConfig, OSError, ValueError) as e:
            # OSError: file missing/unreadable; ValueError: covers the TOML parser's decode error. Neither may take down the running bot.
            logger.error(f'Config reload failed: {e!s} (old config remains active)')
            return
        config = newConfig
        configure_logging(config)
        irc.update_config(config)
        webserver.update_config(config)
        storage.update_config(config)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent), storage.run(loop, sigintEvent))
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Only when executed as a script: drive main() to completion on a fresh event loop.
    entryPoint = main()
    asyncio.run(entryPoint)