@@ -3,13 +3,17 @@ import aiohttp.web
import asyncio
import base64
import collections
import concurrent.future s
import functool s
import importlib.util
import inspect
import ircstates
import irctokens
import itertools
import json
import logging
import os.path
import signal
import socket
import ssl
import string
import sys
@@ -53,14 +57,20 @@ async def wait_cancel_pending(aws, paws = None, **kwargs):
if paws is None:
paws = set()
tasks = aws | paws
logger.debug(f'waiting for {tasks!r}')
done, pending = await asyncio.wait(tasks, **kwargs)
logger.debug(f'done waiting for {tasks!r}; cancelling pending non-persistent tasks: {pending!r}')
for task in pending:
if task not in paws:
logger.debug(f'cancelling {task!r}')
task.cancel()
logger.debug(f'awaiting cancellation of {task!r}')
try:
await task
except asyncio.CancelledError:
pass
logger.debug(f'done cancelling {task!r}')
logger.debug(f'done wait_cancel_pending {tasks!r}')
return done, pending
@@ -92,7 +102,7 @@ class Config(dict):
except (ValueError, AssertionError) as e:
raise InvalidConfig('Invalid log format: parsing failed') from e
if 'irc' in obj:
if any(x not in ('host', 'port', 'ssl', 'nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
if any(x not in ('host', 'port', 'ssl', 'family', ' nick', 'real', 'certfile', 'certkeyfile') for x in obj['irc']):
raise InvalidConfig('Unknown key found in irc section')
if 'host' in obj['irc'] and not isinstance(obj['irc']['host'], str): #TODO: Check whether it's a valid hostname
raise InvalidConfig('Invalid IRC host')
@@ -100,6 +110,10 @@ class Config(dict):
raise InvalidConfig('Invalid IRC port')
if 'ssl' in obj['irc'] and obj['irc']['ssl'] not in ('yes', 'no', 'insecure'):
raise InvalidConfig(f'Invalid IRC SSL setting: {obj["irc"]["ssl"]!r}')
if 'family' in obj['irc']:
if obj['irc']['family'] not in ('inet', 'INET', 'inet6', 'INET6'):
raise InvalidConfig('Invalid IRC family')
obj['irc']['family'] = getattr(socket, f'AF_{obj["irc"]["family"].upper()}')
if 'nick' in obj['irc'] and not isinstance(obj['irc']['nick'], str): #TODO: Check whether it's a valid nickname
raise InvalidConfig('Invalid IRC nick')
if len(IRCClientProtocol.nick_command(obj['irc']['nick'])) > 510:
@@ -192,7 +206,12 @@ class Config(dict):
raise InvalidConfig(f'Invalid map {key!r} overlongmode: unsupported value')
# Default values
finalObj = {'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'}, 'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
finalObj = {
'logging': {'level': 'INFO', 'format': '{asctime} {levelname} {name} {message}'},
'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'family': 0, 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None},
'web': {'host': '127.0.0.1', 'port': 8080},
'maps': {}
}
# Fill in default values for the maps
for key, map_ in obj['maps'].items():
@@ -253,7 +272,7 @@ class Config(dict):
class MessageQueue:
# An object holding onto the messages received from nodeping
# An object holding onto the messages received over HTTP for sending to IRC
# This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
# Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
# Differences to asyncio.Queue include:
@@ -310,12 +329,14 @@ class MessageQueue:
class IRCClientProtocol(asyncio.Protocol):
logger = logging.getLogger('http2irc.IRCClientProtocol')
def __init__(self, m essageQueue, connectionClosedEvent, loop, config, channels):
self.messageQueue = m essageQueue
def __init__(self, http2ircM essageQueue, connectionClosedEvent, loop, config, channels):
self.http2ircMessageQueue = http2ircM essageQueue
self.connectionClosedEvent = connectionClosedEvent
self.loop = loop
self.config = config
self.lastRecvTime = None
self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit
self.sendQueue = asyncio.Queue()
self.buffer = b''
self.connected = False
self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
@@ -323,7 +344,14 @@ class IRCClientProtocol(asyncio.Protocol):
self.pongReceivedEvent = asyncio.Event()
self.sasl = bool(self.config['irc']['certfile'] and self.config['irc']['certkeyfile'])
self.authenticated = False
self.usermask = None
self.server = ircstates.Server(self.config['irc']['host'])
self.capReqsPending = set() # Capabilities requested from the server but not yet ACKd or NAKd
self.caps = set() # Capabilities acknowledged by the server
self.whoxQueue = collections.deque() # Names of channels that were joined successfully but for which no WHO (WHOX) query was sent yet
self.whoxChannel = None # Name of channel for which a WHO query is currently running
self.whoxReply = [] # List of (nickname, account) tuples from the currently running WHO query
self.whoxStartTime = None
self.userChannels = collections.defaultdict(set) # List of which channels a user is known to be in; nickname:str -> {channel:str, ...}
@staticmethod
def nick_command(nick: str):
@@ -334,17 +362,16 @@ class IRCClientProtocol(asyncio.Protocol):
nickb = nick.encode('utf-8')
return b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + real.encode('utf-8')
def _maybe_set_usermask(self, usermask):
if b'@' in usermask and b'!' in usermask.split(b'@')[0] and all(x not in usermask for x in (b' ', b'*', b'#', b'&')):
self.usermask = usermask
self.logger.debug(f'Usermask is now {usermask!r}')
def connection_made(self, transport):
self.logger.info('IRC connected')
self.transport = transport
self.connected = True
caps = [b'multi-prefix', b'userhost-in-names', b'away-notify', b'account-notify', b'extended-join']
if self.sasl:
self.send(b'CAP REQ :sasl')
caps.append(b'sasl')
for cap in caps:
self.capReqsPending.add(cap.decode('ascii'))
self.send(b'CAP REQ :' + cap)
self.send(self.nick_command(self.config['irc']['nick']))
self.send(self.user_command(self.config['irc']['nick'], self.config['irc']['real']))
@@ -393,15 +420,41 @@ class IRCClientProtocol(asyncio.Protocol):
self._send_join_part(b'JOIN', channelsToJoin)
def send(self, data):
self.logger.debug(f'S end: {data!r}')
self.logger.debug(f'Queueing for s end: {data!r}')
if len(data) > 510:
raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}')
self.sendQueue.put_nowait(data)
def _direct_send(self, data):
self.logger.debug(f'Send: {data!r}')
time_ = time.time()
self.transport.write(data + b'\r\n')
return time_
async def send_queue(self):
while True:
self.logger.debug('Trying to get data from send queue')
t = asyncio.create_task(self.sendQueue.get())
done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED)
if self.connectionClosedEvent.is_set():
break
assert t in done, f'{t!r} is not in {done!r}'
data = t.result()
self.logger.debug(f'Got {data!r} from send queue')
now = time.time()
if self.lastSentTime is not None and now - self.lastSentTime < 1:
self.logger.debug(f'Rate limited')
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now)
if self.connectionClosedEvent.is_set():
break
time_ = self._direct_send(data)
if self.lastSentTime is not None:
self.lastSentTime = time_
async def _get_message(self):
self.logger.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
messageFuture = asyncio.create_task(self.messageQueue.get())
done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = concurrent.futures.FIRST_COMPLETED)
self.logger.debug(f'Message queue {id(self.http2ircMessageQueue)} length: {self.http2ircM essageQueue.qsize()}')
messageFuture = asyncio.create_task(self.http2ircM essageQueue.get())
done, pending = await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, paws = {messageFuture}, return_when = asyn ci o.FIRST_COMPLETED)
if self.connectionClosedEvent.is_set():
if messageFuture in pending:
self.logger.debug('Cancelling messageFuture')
@@ -413,11 +466,16 @@ class IRCClientProtocol(asyncio.Protocol):
pass
else:
# messageFuture is already done but we're stopping, so put the result back onto the queue
self.m essageQueue.putleft_nowait(messageFuture.result())
return None, None
self.http2ircM essageQueue.putleft_nowait(messageFuture.result())
return None, None, None
assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
return messageFuture.result()
def _self_usermask_length(self):
if not self.server.nickname or not self.server.username or not self.server.hostname:
return 100
return len(self.server.nickname) + len(self.server.username) + len(self.server.hostname)
async def send_messages(self):
while self.connected:
self.logger.debug(f'Trying to get a message')
@@ -427,7 +485,7 @@ class IRCClientProtocol(asyncio.Protocol):
break
channelB = channel.encode('utf-8')
messageB = message.encode('utf-8')
usermaskPrefixLength = 1 + (len(self.usermask) if self.usermask else 100 ) + 1
usermaskPrefixLength = 1 + self._self_usermask_length( ) + 1
if usermaskPrefixLength + len(b'PRIVMSG ' + channelB + b' :' + messageB) > 510:
# Message too long, need to split or truncate. First try to split on spaces, then on codepoints. Ideally, would use graphemes between, but that's too complicated.
self.logger.debug(f'Message too long, overlongmode = {overlongmode}')
@@ -466,20 +524,19 @@ class IRCClientProtocol(asyncio.Protocol):
messageB = message.encode('utf-8')
if overlongmode == 'split':
for msg in reversed(messages):
self.m essageQueue.putleft_nowait((channel, msg, overlongmode))
self.http2ircM essageQueue.putleft_nowait((channel, msg, overlongmode))
elif overlongmode == 'truncate':
self.m essageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
self.http2ircM essageQueue.putleft_nowait((channel, messages[0] + '…', overlongmode))
else:
self.logger.info(f'Sending {message!r} to {channel!r}')
self.unconfirmedMessages.append((channel, message, overlongmode))
self.send(b'PRIVMSG ' + channelB + b' :' + messageB)
await asyncio.sleep(1) # Rate limit
async def confirm_messages(self):
while self.connected:
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = concurrent.futures .FIRST_COMPLETED, timeout = 60) # Confirm once per minute
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyn ci o.FIRST_COMPLETED, timeout = 60) # Confirm once per minute
if not self.connected: # Disconnected while sleeping, can't confirm unconfirmed messages, requeue them directly
self.m essageQueue.putleft_nowait(*self.unconfirmedMessages)
self.http2ircM essageQueue.putleft_nowait(*self.unconfirmedMessages)
self.unconfirmedMessages = []
break
if not self.unconfirmedMessages:
@@ -488,18 +545,19 @@ class IRCClientProtocol(asyncio.Protocol):
self.logger.debug('Trying to confirm message delivery')
self.pongReceivedEvent.clear()
self.send(b'PING :42')
await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = concurrent.futures .FIRST_COMPLETED, timeout = 5)
await wait_cancel_pending({asyncio.create_task(self.pongReceivedEvent.wait())}, return_when = asyn ci o.FIRST_COMPLETED, timeout = 5)
self.logger.debug(f'Message delivery successful: {self.pongReceivedEvent.is_set()}')
if not self.pongReceivedEvent.is_set():
# No PONG received in five seconds, assume connection's dead
self.logger.warning(f'Message delivery confirmation failed, putting {len(self.unconfirmedMessages)} messages back into the queue')
self.m essageQueue.putleft_nowait(*self.unconfirmedMessages)
self.http2ircM essageQueue.putleft_nowait(*self.unconfirmedMessages)
self.transport.close()
self.unconfirmedMessages = []
def data_received(self, data):
time_ = time.time()
self.logger.debug(f'Data received: {data!r}')
self.lastRecvTime = time.time()
self.lastRecvTime = time_
# If there's any data left in the buffer, prepend it to the data. Split on CRLF.
# Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer.
# If data does end with CRLF, all messages will have been processed and the buffer will be empty again.
@@ -507,104 +565,146 @@ class IRCClientProtocol(asyncio.Protocol):
data = self.buffer + data
messages = data.split(b'\r\n')
for message in messages[:-1]:
self.message_received(message)
lines = self.server.recv(message + b'\r\n')
assert len(lines) == 1, f'recv did not return exactly one line: {message!r} -> {lines!r}'
self.message_received(time_, message, lines[0])
self.server.parse_tokens(lines[0])
self.buffer = messages[-1]
def message_received(self, message):
self.logger.debug(f'Message received: {message!r}')
rawMessage = message
if message.startswith(b':') and b' ' in message:
# Prefixed message, extract command + parameters (the prefix cannot contain a space)
message = message.split(b' ', 1)[1]
def message_received(self, time_, message, line):
self.logger.debug(f'Message received at {time_}: {message!r}')
maybeTriggerWhox = False
# PING/PONG
if message.startswith(b'PING ') :
self.send(b'PONG ' + message[5:] )
elif message.startswith(b'PONG ') :
if line.command == 'PING' :
self._direct_send(irctokens.build('PONG', line.params).format().encode('utf-8') )
elif line.command == 'PONG' :
self.pongReceivedEvent.set()
# SASL
elif message.startswith(b'CAP ') and self.sasl:
if message[message.find(b' ', 4) + 1:] == b'ACK :sasl':
self.send(b'AUTHENTICATE EXTERNAL')
else:
self.logger.error(f'Received unexpected CAP reply {message!r}, terminating connection')
self.transport.close()
elif message == b'AUTHENTICATE +':
# IRCv3 and SASL
elif line.command == 'CAP':
if line.params[1] == 'ACK':
for cap in line.params[2].split(' '):
self.logger.debug(f'CAP ACK: {cap}')
self.caps.add(cap)
if cap == 'sasl' and self.sasl:
self.send(b'AUTHENTICATE EXTERNAL')
else:
self.capReqsPending.remove(cap)
elif line.params[1] == 'NAK':
self.logger.warning(f'Failed to activate CAP(s): {line.params[2]}')
for cap in line.params[2].split(' '):
self.capReqsPending.remove(cap)
if len(self.capReqsPending) == 0:
self.send(b'CAP END')
elif line.command == 'AUTHENTICATE' and line.params == ['+']:
self.send(b'AUTHENTICATE +')
elif message.startswith(b'900 '): # "You are now logged in", includes the usermask
words = message.split(b' ')
if len(words) >= 3 and b'!' in words[2] and b'@' in words[2]:
if b'!~' not in words[2]:
# At least Charybdis seems to always return the user without a tilde, even if identd failed. Assume no identd and account for that extra tilde.
words[2] = words[2].replace(b'!', b'!~', 1)
self._maybe_set_usermask(words[2])
elif message.startswith(b'903 '): # SASL auth successful
elif line.command == ircstates.numerics.RPL_SASLSUCCESS:
self.authenticated = True
self.send(b'CAP END')
elif any(message.startswith(x) for x in (b'902 ', b'904 ', b'905 ', b'906 ', b'908 ')):
self.capReqsPending.remove('sasl')
if len(self.capReqsPending) == 0:
self.send(b'CAP END')
elif line.command in ('902', ircstates.numerics.ERR_SASLFAIL, ircstates.numerics.ERR_SASLTOOLONG, ircstates.numerics.ERR_SASLABORTED, ircstates.numerics.RPL_SASLMECHS):
self.logger.error('SASL error, terminating connection')
self.transport.close()
# NICK errors
elif any(message.startswith(x) for x in (b'431 ', b'432 ', b'433 ', b'436 ') ):
elif line.command in ('431', ircstates.numerics.ERR_ERRONEUSNICKNAME, ircstates.numerics.ERR_NICKNAMEINUSE, '436' ):
self.logger.error(f'Failed to set nickname: {message!r}, terminating connection')
self.transport.close()
# USER errors
elif any(message.startswith(x) for x in (b'461 ', b'462 ') ):
elif line.command in ('461', '462' ):
self.logger.error(f'Failed to register: {message!r}, terminating connection')
self.transport.close()
# JOIN errors
elif any(message.startswith(x) for x in (b'405 ', b'471 ', b'473 ', b'474 ', b'475 ')):
elif line.command in (
ircstates.numerics.ERR_TOOMANYCHANNELS,
ircstates.numerics.ERR_CHANNELISFULL,
ircstates.numerics.ERR_INVITEONLYCHAN,
ircstates.numerics.ERR_BANNEDFROMCHAN,
ircstates.numerics.ERR_BADCHANNELKEY,
):
self.logger.error(f'Failed to join channel: {message!r}, terminating connection')
self.transport.close()
# PART errors
elif message.startswith(b'442 '):
elif line.command == '442' :
self.logger.error(f'Failed to part channel: {message!r}')
# JOIN/PART errors
elif message.startswith(b'403 ') :
elif line.command == ircstates.numerics.ERR_NOSUCHCHANNEL :
self.logger.error(f'Failed to join or part channel: {message!r}')
# PRIVMSG errors
elif any(message.startswith(x) for x in (b'401 ', b'404 ', b'407 ', b'411 ', b'412 ', b'413 ', b'414 ') ):
elif line.command in (ircstates.numerics.ERR_NOSUCHNICK, '404', '407', '411', '412', '413', '414' ):
self.logger.error(f'Failed to send message: {message!r}')
# Connection registration reply
elif message.startswith(b'001 ') :
elif line.command == ircstates.numerics.RPL_WELCOME :
self.logger.info('IRC connection registered')
if self.sasl and not self.authenticated:
self.logger.error('IRC connection registered but not authenticated, terminating connection')
self.transport.close()
return
self.lastSentTime = time.time()
self._send_join_part(b'JOIN', self.channels)
asyncio.create_task(self.send_messages())
asyncio.create_task(self.confirm_messages())
# JOIN success
elif message.startswith(b'JOIN ') and not self.usermask:
# If this is my own join message, it should contain the usermask in the prefix
if rawMessage.startswith(b':' + self.config['irc']['nick'].encode('utf-8') + b'!') and b' ' in rawMessage:
usermask = rawMessage.split(b' ', 1)[0][1:]
self._maybe_set_usermask(usermask)
# Services host change
elif message.startswith(b'396 '):
words = message.split(b' ')
if len(words) >= 3:
# Sanity check inspired by irssi src/irc/core/irc-servers.c
if not any(x in words[2] for x in (b'*', b'?', b'!', b'#', b'&', b' ')) and not any(words[2].startswith(x) for x in (b'@', b':', b'-')) and words[2][-1:] != b'-':
if b'@' in words[2]: # user@host
self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + words[2])
else: # host (get user from previous mask or settings)
if self.usermask:
user = self.usermask.split(b'@')[0].split(b'!')[1]
else:
user = b'~' + self.config['irc']['nick'].encode('utf-8')
self._maybe_set_usermask(self.config['irc']['nick'].encode('utf-8') + b'!' + user + b'@' + words[2])
# Bot getting KICKed
elif line.command == 'KICK' and line.source and self.server.casefold(line.params[1]) == self.server.casefold(self.server.nickname):
self.logger.warning(f'Got kicked from {line.params[0]}')
kickedChannel = self.server.casefold(line.params[0])
for channel in self.channels:
if self.server.casefold(channel) == kickedChannel:
self.channels.remove(channel)
break
# WHOX on successful JOIN if supported to fetch account information
elif line.command == 'JOIN' and self.server.isupport.whox and line.source and self.server.casefold(line.hostmask.nickname) == self.server.casefold(self.server.nickname):
self.whoxQueue.extend(line.params[0].split(','))
maybeTriggerWhox = True
# WHOX response
elif line.command == ircstates.numerics.RPL_WHOSPCRPL and line.params[1] == '042':
self.whoxReply.append({'nick': line.params[4], 'hostmask': f'{line.params[4]}!{line.params[2]}@{line.params[3]}', 'account': line.params[5] if line.params[5] != '0' else None})
# End of WHOX response
elif line.command == ircstates.numerics.RPL_ENDOFWHO:
# Patch ircstates account info; ircstates does not parse the WHOX reply itself.
for entry in self.whoxReply:
if entry['account']:
self.server.users[self.server.casefold(entry['nick'])].account = entry['account']
self.whoxChannel = None
self.whoxReply = []
self.whoxStartTime = None
maybeTriggerWhox = True
# General fatal ERROR
elif line.command == 'ERROR':
self.logger.error(f'Server sent ERROR: {message!r}')
self.transport.close()
# Send next WHOX if appropriate
if maybeTriggerWhox and self.whoxChannel is None and self.whoxQueue:
self.whoxChannel = self.whoxQueue.popleft()
self.whoxReply = []
self.whoxStartTime = time.time() # Note, may not be the actual start time due to rate limiting
self.send(b'WHO ' + self.whoxChannel.encode('utf-8') + b' c%tuhna,042')
async def quit(self):
# The server acknowledges a QUIT by sending an ERROR and closing the connection. The latter triggers connection_lost, so just wait for the closure event.
self.logger.info('Quitting')
self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue
self._direct_send(b'QUIT :Bye')
await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10)
if not self.connectionClosedEvent.is_set():
self.logger.error('Quitting cleanly did not work, closing connection forcefully')
# Event will be set implicitly in connection_lost.
self.transport.close()
def connection_lost(self, exc):
self.logger.info('IRC connection lost')
@@ -615,8 +715,8 @@ class IRCClientProtocol(asyncio.Protocol):
class IRCClient:
logger = logging.getLogger('http2irc.IRCClient')
def __init__(self, m essageQueue, config):
self.messageQueue = m essageQueue
def __init__(self, http2ircM essageQueue, config):
self.http2ircMessageQueue = http2ircM essageQueue
self.config = config
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
@@ -647,17 +747,43 @@ class IRCClient:
while True:
connectionClosedEvent.clear()
try:
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
self.logger.debug('Creating IRC connection')
t = asyncio.create_task(loop.create_connection(
protocol_factory = lambda: IRCClientProtocol(self.http2ircMessageQueue, connectionClosedEvent, loop, self.config, self.channels),
host = self.config['irc']['host'],
port = self.config['irc']['port'],
ssl = self._get_ssl_context(),
family = self.config['irc']['family'],
))
# No automatic cancellation of t because it's handled manually below.
done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30)
if t not in done:
t.cancel()
await t # Raises the CancelledError
self._transport, self._protocol = t.result()
self.logger.debug('Starting send queue processing')
sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent
self.logger.debug('Waiting for connection closure or SIGINT')
try:
await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = concurrent.futures.FIRST_COMPLETED)
await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyn ci o.FIRST_COMPLETED)
finally:
self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}')
if not connectionClosedEvent.is_set():
self.logger.debug('Quitting connection')
await self._protocol.quit()
if not sendTask.done():
sendTask.cancel()
try:
await sendTask
except asyncio.CancelledError:
pass
self._transport = None
self._protocol = None
except (ConnectionRefusedError, asyncio.TimeoutError) as e:
self.logger.error(str(e))
except (ConnectionError, ssl.SSLError, asyncio.TimeoutError, asyncio.Cancelled Error) as e:
self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}' )
await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5)
if sigintEvent.is_set():
self.logger.debug('Got SIGINT, breaking IRC loop')
break
@property
@@ -668,8 +794,8 @@ class IRCClient:
class WebServer:
logger = logging.getLogger('http2irc.WebServer')
def __init__(self, m essageQueue, ircClient, config):
self.messageQueue = m essageQueue
def __init__(self, http2ircM essageQueue, ircClient, config):
self.http2ircMessageQueue = http2ircM essageQueue
self.ircClient = ircClient
self.config = config
@@ -697,7 +823,7 @@ class WebServer:
await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start()
await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = concurrent.futures .FIRST_COMPLETED)
await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyn ci o.FIRST_COMPLETED)
await runner.cleanup()
if stopEvent.is_set():
break
@@ -735,7 +861,7 @@ class WebServer:
self.logger.debug(f'Processing request {id(request)} using default processor')
message = await self._default_process(request)
self.logger.info(f'Accepted request {id(request)}, putting message {message!r} for {channel} into message queue')
self.m essageQueue.put_nowait((channel, message, overlongmode))
self.http2ircM essageQueue.put_nowait((channel, message, overlongmode))
raise aiohttp.web.HTTPOk()
async def _default_process(self, request):
@@ -777,10 +903,10 @@ async def main():
loop = asyncio.get_running_loop()
m essageQueue = MessageQueue()
http2ircM essageQueue = MessageQueue()
irc = IRCClient(m essageQueue, config)
webserver = WebServer(m essageQueue, irc, config)
irc = IRCClient(http2ircM essageQueue, config)
webserver = WebServer(http2ircM essageQueue, irc, config)
sigintEvent = asyncio.Event()
def sigint_callback():