Browse Source

There was an attempt...

master
JustAnotherArchivist 3 years ago
parent
commit
3834d9d124
1 changed files with 164 additions and 22 deletions
  1. +164
    -22
      irclog.py

+ 164
- 22
irclog.py View File

@@ -3,7 +3,6 @@ import aiohttp.web
import asyncio
import base64
import collections
import concurrent.futures
import importlib.util
import inspect
import itertools
@@ -20,6 +19,8 @@ import toml

logger = logging.getLogger('irclog')
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': ssl.SSLContext()}
messageConnectionClosed = object() # Signals that the connection was closed by either the bot or the server
messageEOF = object() # Special object to signal the end of messages to Storage


class InvalidConfig(Exception):
@@ -82,6 +83,7 @@ class Config(dict):
if 'path' in obj['storage']:
obj['storage']['path'] = os.path.abspath(os.path.join(os.path.dirname(self._filename), obj['storage']['path']))
try:
#TODO This doesn't seem to work correctly; doesn't fail when the dir is -w
f = tempfile.TemporaryFile(dir = obj['storage']['path'])
f.close()
except (OSError, IOError) as e:
@@ -194,6 +196,7 @@ class IRCClientProtocol(asyncio.Protocol):
self.buffer = b''
self.connected = False
self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...}
self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
self.authenticated = False
self.usermask = None
@@ -207,6 +210,25 @@ class IRCClientProtocol(asyncio.Protocol):
nickb = nick.encode('utf-8')
return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')

@staticmethod
def valid_channel(channel: str):
return channel[0] in ('#', '&') and not any(x in channel for x in (' ', '\x00', '\x07', '\r', '\n', ','))

@staticmethod
def valid_nick(nick: str):
# According to RFC 1459, a nick must be '<letter> { <letter> | <number> | <special> }'. This is obviously not true in practice because <special> doesn't include underscores, for example.
# So instead, just do a sanity check similar to the channel one to disallow obvious bullshit.
return not any(x in nick for x in (' ', '\x00', '\x07', '\r', '\n', ','))

@staticmethod
def prefix_to_nick(prefix: str):
nick = prefix[1:]
if '!' in nick:
nick = nick.split('!', 1)[0]
if '@' in nick: # nick@host is also legal
nick = nick.split('@', 1)[0]
return nick

def _maybe_set_usermask(self, usermask):
if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
self.usermask = usermask
@@ -271,7 +293,7 @@ class IRCClientProtocol(asyncio.Protocol):
raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
time_ = time.time()
self.transport.write(data + b'\r\n')
self.messageQueue.put_nowait((time_, b'> ' + data))
self.messageQueue.put_nowait((time_, b'> ' + data, None))

def data_received(self, data):
self.logger.debug(f'Data received: {data!r}')
@@ -290,12 +312,15 @@ class IRCClientProtocol(asyncio.Protocol):
def message_received(self, time_, message):
self.logger.debug(f'Message received at {time_}: {message!r}')
rawMessage = message
hasPrefix = False
if message.startswith(b':') and b' ' in message:
# Prefixed message, extract command + parameters (the prefix cannot contain a space)
message = message.split(b' ', 1)[1]
prefix, message = message.split(b' ', 1)
hasPrefix = True

# Queue message for storage
self.messageQueue.put_nowait((time_, b'< ' + rawMessage))
# Queue message for storage, except QUITs and NICKs which are handled below with user tracking
if not message.startswith(b'QUIT ') and message != b'QUIT' and not message.startswith(b'NICK '):
self.messageQueue.put_nowait((time_, b'< ' + rawMessage, None))

# PING/PONG
if message.startswith(b'PING '):
@@ -382,10 +407,94 @@ class IRCClientProtocol(asyncio.Protocol):
user = b'~' + self.config['irc']['nick'].encode('utf-8')
self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])

# User tracking (for NICK and QUIT)
decoded = False
if any(message.startswith(x) for x in (b'353 ', b'JOIN ', b'PART ', b'KICK ', b'NICK ', b'QUIT ')) or message == b'QUIT':
try:
if hasPrefix:
prefixStr = prefix.decode('utf-8')
messageStr = message.decode('utf-8')
except UnicodeDecodeError as e:
self.logger.warning(f'Could not decode prefix/message {prefix!r}/{message!r} ({e!s}), user tracking may be wrong')
else:
decoded = True
if message.startswith(b'353 ') and decoded: # RPL_NAMREPLY
_, channel, nicksStr = messageStr.split(' ', 2)
if nicksStr.startswith(':'): # It always should, but who knows...
nicksStr = nicksStr[1:]
nicks = nicksStr.split(' ')
for nick in nicks:
if nick[0] in ('@', '+'):
nick = nick[1:]
if self.valid_channel(channel) and self.valid_nick(nick):
self.userChannels[nick].add(channel)
if (message.startswith(b'JOIN ') or message.startswith(b'PART ')) and decoded and hasPrefix:
nick = self.prefix_to_nick(prefixStr)
channels = messageStr[5:] # Could be more than one channel in theory
for channel in channels.split(','):
if self.valid_channel(channel) and self.valid_nick(nick):
if message.startswith(b'JOIN '):
self.userChannels[nick].add(channel)
else:
self.userChannels[nick].discard(channel)
if message.startswith(b'KICK ') and decoded: # Prefix is supposed to indicate who kicked the user, but we don't care about that for the user tracking.
_, channel, nick = messageStr.split(' ', 2)
if ' ' in nick: # There might be a kick reason after the nick
nick = nick.split(' ', 1)[0]
if self.valid_channel(channel) and self.valid_nick(nick):
self.userChannels[nick].discard(channel)
if message.startswith(b'NICK '):
# If something can't be processed, just send it to storage without user tracking.
sendGeneric = True
if decoded and hasPrefix:
oldNick = self.prefix_to_nick(prefixStr)
newNick = message[5:]
if self.valid_nick(oldNick) and self.valid_nick(newNick) and oldNick in self.userChannels:
self.userChannels[newNick] = self.userChannels[oldNick]
del self.userChannels[oldNick]
if self.userChannels[newNick]:
sendGeneric = False
self.messageQueue.put_nowait((time_, rawMessage, self.userChannels[newNick]))
if sendGeneric:
self.logger.warning(f'Could not process nick change {rawMessage!r}, user tracking may be wrong')
self.messageQueue.put_nowait((time_, rawMessage, None))
if message.startswith(b'QUIT ') or message == b'QUIT':
# Technically a simple 'QUIT' is not legal per RFC 1459. That's because there must always be a space after the command due to how <params> is defined.
# In practice, it is accepted by ircds though, so it can presumably also be received by a client.
sendGeneric = True
if decoded and hasPrefix:
nick = self.prefix_to_nick(prefixStr)
if nick != self.config['irc']['nick'] and nick in self.userChannels:
if self.userChannels[nick]:
sendGeneric = False
self.messageQueue.put_nowait((time_, rawMessage, self.userChannels[nick]))
del self.userChannels[nick]
if not hasPrefix or (decoded and hasPrefix and nick == self.config['irc']['nick']):
# Oh no, *I* am getting disconnected! :-(
# I'm not actually sure whether the prefix version can happen, but better safe than sorry...
# In this case, it should be logged to all channels as well as the general log. The extra 'general' entry triggers Storage's code to write a message to the general log.
# Side effect: if the connection dies before any channels were joined, this causes the quit to be logged everywhere. However, there won't be a JOIN in the log, so it would still be unambiguous.
# Also, the connection loss after the disconnect triggers another message to be written to the logs. ¯\_(ツ)_/¯
sendGeneric = False
self.messageQueue.put_nowait((time_, rawMessage, list(self.channels) + ['general']))
if sendGeneric:
self.logger.warning(f'Could not process quit message {rawMessage!r}, user tracking may be wrong')
self.messageQueue.put_nowait((time_, rawMessage, None))

async def quit(self):
# It appears to be hard to impossible to send a clean quit, wait for it to be actually sent, and only then close the transport.
# This is because asyncio.sslproto.SSLTransport doesn't support explicit draining nor waiting for an empty write queue nor write_eof.
# So instead, just close the transport and wait until connection_lost is triggered (which also puts a message in the logs).
self.logger.info('Quitting')
self.transport.close()
await self.connectionClosedEvent.wait()

def connection_lost(self, exc):
time_ = time.time()
self.logger.info('IRC connection lost')
self.connected = False
self.connectionClosedEvent.set()
self.messageQueue.put_nowait((time_, messageConnectionClosed, list(self.channels) + ['general']))


class IRCClient:
@@ -425,13 +534,16 @@ class IRCClient:
try:
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
try:
await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = asyncio.FIRST_COMPLETED)
finally:
self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
if not connectionClosedEvent.is_set():
await self._protocol.quit()
except (ConnectionRefusedError, asyncio.TimeoutError) as e:
self.logger.error(str(e))
await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = asyncio.FIRST_COMPLETED)
if sigintEvent.is_set():
self.logger.debug('Got SIGINT, putting EOF and breaking')
self.messageQueue.put_nowait(messageEOF)
break


@@ -456,28 +568,40 @@ class Storage:
del self.files[channel]

#TODO mkdir as required
#TODO month

for channel in self.config['channels'].values():
if channel['ircchannel'] not in self.files and channel['active']:
self.files[channel['ircchannel']] = open(os.path.join(self.config['storage']['path'], channel['ircchannel'], '2020-10.log'), 'ab')

if None not in self.files:
self.files[None] = open(os.path.join(self.config['storage']['path'], 'general', '2020-10.log'), 'ab') #TODO Month
self.files[None] = open(os.path.join(self.config['storage']['path'], 'general', '2020-10.log'), 'ab')

async def run(self, loop, sigintEvent):
self.update_config(self.config) # Ensure that files are open etc.
#TODO Task to rotate log files at the beginning of a new month
storageTask = asyncio.create_task(self.store_messages(sigintEvent))
flushTask = asyncio.create_task(self.flush_files(sigintEvent))
flushTask = asyncio.create_task(self.flush_files())
await sigintEvent.wait()
self.logger.debug('Got SIGINT, waiting for remaining messages to be stored')
await storageTask # Wait until everything's stored
self.active = False
#TODO Wait for tasks
self.logger.debug('Waiting for flush task')
await flushTask
self.close()

async def store_messages(self, sigintEvent):
while self.active:
#TODO wait for sigint as well
time_, rawMessage = await self.messageQueue.get()
self.logger.debug('Waiting for message')
res = await self.messageQueue.get()
self.logger.debug(f'Got {res!r} from message queue')
if res is messageEOF:
self.logger.debug('Message EOF, breaking store_messages loop')
break

time_, rawMessage, channels = res
if rawMessage is messageConnectionClosed:
rawMessage = b'- Connection closed'
message = rawMessage[2:] # Remove leading > or <
if message.startswith(b':') and b' ' in message:
prefix, message = message.split(b' ', 1)
@@ -492,9 +616,15 @@ class Storage:
for channel in channels:
self.store_message(time_, rawMessage, channel)
continue
if message.startswith(b'QUIT '):
#TODO Need to keep track of users to figure out in which channels they were... Ugh
pass
if message.startswith(b'QUIT ') or message == b'QUIT' or message.startswith(b'NICK '):
# If channels is not None, IRCClientProtocol managed to track the user and identify the channels this needs to be logged to.
# If it isn't, there might be channels in there (for some odd reason?) that are not being logged. In that case, emit one and only one message to the general log as well.
if channels is not None:
for channel in channels:
self.store_message(time_, rawMessage, channel, redirectToGeneral = False)
if channels is None or any(channel not in self.files for channel in channels):
self.store_message(time_, rawMessage, None)
continue
if message.startswith(b'MODE #') or message.startswith(b'MODE &') or message.startswith(b'KICK '):
channel = message.split(b' ', 2)[1]
channel = self.decode_channel(time_, rawMessage, channel)
@@ -502,10 +632,18 @@ class Storage:
continue
self.store_message(time_, rawMessage, channel)
continue
self.store_message(time_, rawMessage, None)
if channels is not None:
for channel in channels:
self.store_message(time_, rawMessage, channel)
else:
self.store_message(time_, rawMessage, None)

def store_message(self, time_, rawMessage, targetChannel):
def store_message(self, time_, rawMessage, targetChannel, redirectToGeneral = True):
self.logger.debug(f'Logging {rawMessage!r} at {time_} for {targetChannel!r}')
if targetChannel is not None and targetChannel not in self.files:
self.logger.debug(f'Target channel {targetChannel!r} not opened, redirecting to general log is {redirectToGeneral}')
if not redirectToGeneral:
return
targetChannel = None
self.files[targetChannel].write(str(time_).encode('ascii') + b' ' + rawMessage + b'\r\n')

@@ -519,9 +657,10 @@ class Storage:
self.store_message(time_, rawMessage, None)
return None

async def flush_files(self, sigintEvent):
async def flush_files(self):
while self.active:
await sigintEvent.wait()
await asyncio.sleep(1)
self.logger.debug('Exiting flush_files')

def close(self):
for f in self.files.values():
@@ -556,7 +695,7 @@ class WebServer:
await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start()
await asyncio.wait((stopEvent.wait(), self._configChanged.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
await asyncio.wait((stopEvent.wait(), self._configChanged.wait()), return_when = asyncio.FIRST_COMPLETED)
await runner.cleanup()
if stopEvent.is_set():
break
@@ -637,6 +776,9 @@ async def main():
loop = asyncio.get_running_loop()

messageQueue = asyncio.Queue()
# tuple(time: float, message: bytes or None, channels: list[str] or None)
# message = None indicates a connection loss
# channels = None indicates that IRCClientProtocol did not identify which channels are affected; it is only not None for QUIT or NICK messages.

irc = IRCClient(messageQueue, config)
webserver = WebServer(config)
@@ -652,7 +794,7 @@ async def main():

def sigusr1_callback():
global logger
nonlocal config, irc, webserver
nonlocal config, irc, webserver, storage
logger.info('Got SIGUSR1, reloading config')
try:
newConfig = config.reread()


Loading…
Cancel
Save