Explorar el Código

Refactor into a more flexible tool supporting multiple endpoints and channels (nodeping2irc -> http2irc), with no-downtime config changes: the settings live in a TOML config file, and SIGUSR1 makes the running process reload it and adapt accordingly

master
JustAnotherArchivist hace 4 años
padre
commit
e9a7780450
Se han modificado 2 ficheros con 375 adiciones y 246 borrados
  1. +375
    -0
      http2irc.py
  2. +0
    -246
      nodeping2irc.py

+ 375
- 0
http2irc.py Ver fichero

@@ -0,0 +1,375 @@
import aiohttp
import aiohttp.web
import asyncio
import base64
import collections
import concurrent.futures
import json
import logging
import signal
import ssl
import sys
import toml
import types


logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')


def _insecure_ssl_context():
    '''Builds a TLS client context that verifies neither the certificate nor the hostname'''
    # ssl.SSLContext() without a protocol argument is deprecated since Python 3.10, so the insecure settings are spelled out.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.check_hostname = False # Has to be disabled before verify_mode may be set to CERT_NONE
    context.verify_mode = ssl.CERT_NONE
    return context


# Maps the accepted values of the irc.ssl config setting to the `ssl` argument passed to loop.create_connection
SSL_CONTEXTS = {'yes': True, 'no': False, 'insecure': _insecure_ssl_context()}


class InvalidConfig(Exception):
    '''Raised when the configuration file has unknown sections or keys, or values that fail validation'''


def _mapping_to_namespace(d):
'''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively'''
return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()})


class Config:
    '''Settings read from a TOML file, exposed as the namespaces `irc`, `web`, and `maps`

    Settings missing from the file are filled in with defaults. The constructor raises InvalidConfig if the file
    contains unknown sections or keys, or values that fail the sanity checks.
    '''

    def __init__(self, filename):
        self._filename = filename
        # Set below:
        self.irc = None
        self.web = None
        self.maps = None

        with open(self._filename, 'r') as fp:
            obj = toml.load(fp)

        # NOTE(review): this writes the whole config to the log, including any `auth` values of the maps.
        logging.info(repr(obj))

        # Sanity checks
        if any(x not in ('irc', 'web', 'maps') for x in obj.keys()):
            raise InvalidConfig('Unknown sections found in base object')
        if any(not isinstance(x, collections.abc.Mapping) for x in obj.values()):
            raise InvalidConfig('Invalid section type(s), expected objects/dicts')
        if 'irc' in obj:
            irc = obj['irc']
            if any(x not in ('host', 'port', 'ssl', 'nick', 'real') for x in irc):
                raise InvalidConfig('Unknown key found in irc section')
            if 'host' in irc and not isinstance(irc['host'], str): #TODO: Check whether it's a valid hostname
                raise InvalidConfig('Invalid IRC host')
            if 'port' in irc and not self._is_port(irc['port']):
                raise InvalidConfig('Invalid IRC port')
            if 'ssl' in irc and irc['ssl'] not in ('yes', 'no', 'insecure'):
                raise InvalidConfig(f'Invalid IRC SSL setting: {irc["ssl"]!r}')
            if 'nick' in irc and not isinstance(irc['nick'], str): #TODO: Check whether it's a valid nickname
                raise InvalidConfig('Invalid IRC nick')
            if 'real' in irc and not isinstance(irc['real'], str):
                raise InvalidConfig('Invalid IRC realname')
        if 'web' in obj:
            web = obj['web']
            if any(x not in ('host', 'port') for x in web):
                raise InvalidConfig('Unknown key found in web section')
            if 'host' in web and not isinstance(web['host'], str): #TODO: Check whether it's a valid hostname (must resolve I guess?)
                raise InvalidConfig('Invalid web hostname')
            if 'port' in web and not self._is_port(web['port']):
                raise InvalidConfig('Invalid web port')
        # The maps section is optional; treating a missing one as empty keeps the default filling below from failing.
        maps = obj.get('maps', {})
        for key, map_ in maps.items():
            # Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace.
            #TODO: Support for fancier identifiers (PEP 3131)?
            if not isinstance(key, str) or not key or key.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_') != '' or key[0].strip('0123456789') == '':
                raise InvalidConfig(f'Invalid map key {key!r}')
            if not isinstance(map_, collections.abc.Mapping):
                raise InvalidConfig(f'Invalid map for {key!r}')
            if any(x not in ('webpath', 'ircchannel', 'auth') for x in map_):
                raise InvalidConfig(f'Unknown key(s) found in map {key!r}')
            # Type checks only; these values are later used as strings (path lookup, channel name, basic auth encoding).
            #TODO: Check the values beyond their types
            if 'webpath' in map_ and not isinstance(map_['webpath'], str):
                raise InvalidConfig(f'Invalid webpath in map {key!r}')
            if 'ircchannel' in map_ and not isinstance(map_['ircchannel'], str):
                raise InvalidConfig(f'Invalid ircchannel in map {key!r}')
            if 'auth' in map_ and map_['auth'] is not False and not isinstance(map_['auth'], str):
                raise InvalidConfig(f'Invalid auth in map {key!r}')

        # Default values
        self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.'}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}

        # Fill in default values for the maps: path and channel are derived from the map's key, authentication is off
        for key, map_ in maps.items():
            if 'webpath' not in map_:
                map_['webpath'] = f'/{key}'
            if 'ircchannel' not in map_:
                map_['ircchannel'] = f'#{key}'
            if 'auth' not in map_:
                map_['auth'] = False

        # Merge in what was read from the config file and convert to SimpleNamespace
        for key in ('irc', 'web', 'maps'):
            if key in obj:
                self._obj[key].update(obj[key])
            setattr(self, key, _mapping_to_namespace(self._obj[key]))

    @staticmethod
    def _is_port(value):
        '''Whether value is an integer from 1 to 65535; bools are rejected although they are ints in Python'''
        return isinstance(value, int) and not isinstance(value, bool) and 1 <= value <= 65535

    def __repr__(self):
        return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})'

    def reread(self):
        '''Returns a new Config read from the same file; this object is left unchanged'''
        return Config(self._filename)


class MessageQueue:
    # An object holding onto the messages that are waiting to be delivered to IRC
    # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    # Differences to asyncio.Queue include:
    # - No maxsize
    # - No put coroutine (not necessary since the queue can never be full)
    # - Only one concurrent getter
    # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

    def __init__(self):
        self._getter = None # None | asyncio.Future
        self._queue = collections.deque()

    async def get(self):
        '''Returns the item at the front of the queue, waiting for a put if the queue is empty

        Raises RuntimeError if another get is still waiting.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        if len(self._queue) == 0:
            self._getter = asyncio.get_running_loop().create_future()
            logging.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                logging.debug('Cancelled getter')
                self._getter = None
                raise
            logging.debug('Awaited getter')
            self._getter = None
        # For testing the cancellation/putting back onto the queue
        #logging.debug('Delaying message queue get')
        #await asyncio.sleep(3)
        #logging.debug('Done delaying')
        return self.get_nowait()

    def get_nowait(self):
        '''Returns the item at the front of the queue; raises asyncio.QueueEmpty if there is none'''
        if len(self._queue) == 0:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        '''Appends item to the back of the queue and wakes up a waiting get'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, item):
        '''Puts item at the front of the queue and wakes up a waiting get'''
        self._queue.appendleft(item)
        self._wake_getter()

    def _wake_getter(self):
        # The getter future stays assigned until the waiting get() actually resumes, so further puts (or a
        # cancellation) can find it already finished; set_result on a finished future raises InvalidStateError.
        if self._getter is not None and not self._getter.done():
            self._getter.set_result(None)

    def qsize(self):
        '''Returns the number of items currently in the queue'''
        return len(self._queue)


class IRCClientProtocol(asyncio.Protocol):
    '''Protocol for a single IRC connection

    On connect it registers with the configured nick and realname, joins the channels, and starts a task that sends
    the (channel, message) tuples from the message queue as PRIVMSGs. Incoming PINGs are answered; every other
    incoming message is only logged. connectionClosedEvent is set when the connection is lost.
    '''

    def __init__(self, messageQueue, connectionClosedEvent, loop, config, channels):
        logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {connectionClosedEvent}, {loop}')
        self.messageQueue = messageQueue
        self.connectionClosedEvent = connectionClosedEvent
        self.loop = loop
        self.config = config
        self.buffer = b'' # Received bytes that do not form a complete line yet
        self.connected = False
        self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str)
        self._sendTask = None # Task running send_messages; referenced here so it is not lost while running

    def connection_made(self, transport):
        logging.info('Connected')
        self.transport = transport
        self.connected = True
        nickb = self.config.irc.nick.encode('utf-8')
        self.send(b'NICK ' + nickb)
        self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8'))
        if self.channels: # A JOIN without any channel would be an invalid command
            self.send(b'JOIN ' + ','.join(self.channels).encode('utf-8')) #TODO: Split if too long
        self._sendTask = asyncio.create_task(self.send_messages())

    def update_channels(self, channels: set):
        '''Replaces the channel set and, while connected, parts and joins channels to match it'''
        channelsToPart = self.channels - channels
        channelsToJoin = channels - self.channels
        self.channels = channels

        if self.connected:
            if channelsToPart:
                #TODO: Split if too long
                self.send(b'PART ' + ','.join(channelsToPart).encode('utf-8'))
            if channelsToJoin:
                self.send(b'JOIN ' + ','.join(channelsToJoin).encode('utf-8'))

    def send(self, data):
        '''Logs data and writes it to the transport followed by CRLF'''
        logging.info(f'Send: {data!r}')
        self.transport.write(data + b'\r\n')

    async def _get_message(self):
        '''Returns the next (channel, message) tuple from the queue, or (None, None) if the connection was closed'''
        logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        messageFuture = asyncio.create_task(self.messageQueue.get())
        # asyncio.wait does not accept bare coroutines anymore (Python 3.11+). Wrapping the event wait in an explicit
        # task also allows cancelling it afterwards instead of leaving one pending task behind per message.
        closedFuture = asyncio.create_task(self.connectionClosedEvent.wait())
        try:
            done, pending = await asyncio.wait((messageFuture, closedFuture), return_when = asyncio.FIRST_COMPLETED)
        finally:
            closedFuture.cancel()
        if self.connectionClosedEvent.is_set():
            if messageFuture in pending:
                logging.debug('Cancelling messageFuture')
                messageFuture.cancel()
                try:
                    await messageFuture
                except asyncio.CancelledError:
                    logging.debug('Cancelled messageFuture')
            else:
                # messageFuture is already done but we're stopping, so put the result back onto the queue
                self.messageQueue.putleft_nowait(messageFuture.result())
            return None, None
        assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
        return messageFuture.result()

    async def send_messages(self):
        '''Sends queued messages as PRIVMSGs, one per second, until the connection is closed'''
        while self.connected:
            logging.debug(f'{id(self)}: trying to get a message')
            channel, message = await self._get_message()
            logging.debug(f'{id(self)}: got message: {message!r}')
            if message is None:
                break
            self.send(b'PRIVMSG ' + channel.encode('utf-8') + b' :' + message.encode('utf-8'))
            #TODO self.messageQueue.putleft_nowait if delivery fails
            await asyncio.sleep(1) # Rate limit

    def data_received(self, data):
        logging.debug(f'Data received: {data!r}')
        # Prepend whatever is left in the buffer and split on CRLF. Everything except the last piece is a complete
        # message; the last piece (empty if data ended with CRLF) is incomplete and goes back into the buffer.
        # Joining before splitting also handles chunks without any CRLF and a CRLF that is split across two chunks.
        *messages, self.buffer = (self.buffer + data).split(b'\r\n')
        for message in messages:
            self.message_received(message)

    def message_received(self, message):
        '''Handles one complete line from the server: logs it and answers PINGs'''
        logging.info(f'Message received: {message!r}')
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

    def connection_lost(self, exc):
        logging.info('The server closed the connection')
        self.connected = False
        self.connectionClosedEvent.set()


class IRCClient:
    '''Keeps an IRC connection alive and applies config changes to it'''

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config
        self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}

        self._transport = None
        self._protocol = None

    def update_config(self, config):
        '''Switches to config: reconnects if host, port, or SSL setting changed, otherwise adjusts the joined channels'''
        needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl)
        self.config = config
        # Always track the new channels: the next connection is created from self.channels, also after a
        # forced reconnect or when there is no connection at the moment.
        self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
        if self._transport: # if currently connected:
            if needReconnect:
                # Closing makes run() establish a new connection with the new settings
                self._transport.close()
            else:
                self._protocol.update_channels(self.channels)

    @staticmethod
    async def _wait_first(*coroutines):
        '''Runs the coroutines concurrently until the first one finishes and cancels the others'''
        # asyncio.wait does not accept bare coroutines anymore (Python 3.11+), hence the explicit tasks.
        tasks = [asyncio.ensure_future(coroutine) for coroutine in coroutines]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel() # No effect on tasks that are already done

    async def run(self, loop, sigintEvent):
        '''Connects to the IRC server and reconnects whenever the connection closes, until sigintEvent is set'''
        connectionClosedEvent = asyncio.Event()
        while True:
            connectionClosedEvent.clear()
            try:
                self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config.irc.host, self.config.irc.port, ssl = SSL_CONTEXTS[self.config.irc.ssl])
                try:
                    await self._wait_first(connectionClosedEvent.wait(), sigintEvent.wait())
                finally:
                    self._transport.close() #TODO BaseTransport.close is asynchronous and then triggers the protocol's connection_lost callback; need to wait for connectionClosedEvent again perhaps to correctly handle ^C?
                    # Forget the closed connection so that update_config does not act on it
                    self._transport = None
                    self._protocol = None
            except (OSError, asyncio.TimeoutError) as e:
                # OSError includes ConnectionRefusedError as well as e.g. name resolution failures; all of them are worth a retry.
                logging.error(str(e))
                await self._wait_first(asyncio.sleep(5), sigintEvent.wait())
            if sigintEvent.is_set():
                break


class WebServer:
    '''HTTP server that accepts POSTed JSON messages on the configured paths and queues them for IRC'''

    def __init__(self, messageQueue, config):
        self.messageQueue = messageQueue
        self.config = config

        self._paths = {} # '/path' => ('#channel', auth) where auth is either False (no authentication) or the HTTP header value for basic auth

        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/{path:.+}', self.post)])

        self.update_config(config)

    def update_config(self, config):
        '''Rebuilds the path table from config.maps; a changed bind address is only reported, not applied'''
        self._paths = {map_.webpath: (map_.ircchannel, f'Basic {base64.b64encode(map_.auth.encode("utf-8")).decode("utf-8")}' if map_.auth else False) for map_ in config.maps.__dict__.values()}
        needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port)
        self.config = config
        if needRebind:
            #TODO
            logging.error('Webserver host or port changes while running are currently not supported')

    async def run(self, stopEvent):
        '''Serves on the configured host and port until stopEvent is set'''
        runner = aiohttp.web.AppRunner(self._app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port)
        await site.start()
        await stopEvent.wait()
        await runner.cleanup()

    @staticmethod
    def _check_data(data):
        '''Returns a short description of what is wrong with the decoded JSON body, or None if its message is usable'''
        # Valid JSON is not necessarily an object, and the message not necessarily a string. Anything else must be
        # rejected here because the IRC side encodes the message as a string.
        if not isinstance(data, dict) or 'message' not in data:
            return 'Message missing'
        if not isinstance(data['message'], str):
            return 'Message is not a string'
        if '\r' in data['message'] or '\n' in data['message']:
            return 'Linebreaks in message'
        return None

    async def post(self, request):
        '''Handles a POST: 404 for unknown paths, 403 for failed authentication, 400 for unusable bodies, else queues the message'''
        logging.info(f'Received request for {request.path!r} with data {await request.read()!r}')
        try:
            channel, auth = self._paths[request.path]
        except KeyError:
            raise aiohttp.web.HTTPNotFound()
        if auth:
            authHeader = request.headers.get('Authorization')
            if not authHeader or authHeader != auth:
                raise aiohttp.web.HTTPForbidden()
        try:
            data = await request.json()
        except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
            logging.error(f'Invalid data received: {await request.read()!r}')
            raise aiohttp.web.HTTPBadRequest()
        problem = self._check_data(data)
        if problem is not None:
            logging.error(f'{problem}: {await request.read()!r}')
            raise aiohttp.web.HTTPBadRequest()
        logging.debug(f'Putting message {data["message"]!r} for {channel} into message queue')
        self.messageQueue.put_nowait((channel, data['message']))
        raise aiohttp.web.HTTPOk()


async def main():
    '''Loads the config file named on the command line and runs the IRC client and the web server

    SIGINT stops both; SIGUSR1 rereads the config file and hands the new config to both.
    '''
    if len(sys.argv) != 2:
        print('Usage: http2irc.py CONFIGFILE', file = sys.stderr)
        sys.exit(1)
    configFile = sys.argv[1]
    config = Config(configFile)

    loop = asyncio.get_running_loop()

    messageQueue = MessageQueue()

    irc = IRCClient(messageQueue, config)
    webserver = WebServer(messageQueue, config)

    sigintEvent = asyncio.Event()
    def sigint_callback():
        logging.info('Got SIGINT')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    def sigusr1_callback():
        logging.info('Got SIGUSR1, reloading config')
        nonlocal config
        try:
            newConfig = config.reread()
        except Exception:
            # An unreadable or invalid config file must not disturb the running services; keep the current config.
            logging.exception('Reloading the config failed, keeping the previous config')
            return
        config = newConfig
        irc.update_config(config)
        webserver.update_config(config)
    loop.add_signal_handler(signal.SIGUSR1, sigusr1_callback)

    await asyncio.gather(irc.run(loop, sigintEvent), webserver.run(sigintEvent))


if __name__ == '__main__':
    asyncio.run(main())

+ 0
- 246
nodeping2irc.py Ver fichero

@@ -1,246 +0,0 @@
import aiohttp
import aiohttp.web
import argparse
import asyncio
import collections
import concurrent.futures
import json
import logging
import signal


logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')


class MessageQueue:
    # An object holding onto the messages received from nodeping
    # This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code.
    # Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that.
    # Differences to asyncio.Queue include:
    # - No maxsize
    # - No put coroutine (not necessary since the queue can never be full)
    # - Only one concurrent getter
    # - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails)

    def __init__(self):
        self._getter = None # None | asyncio.Future
        self._queue = collections.deque()

    async def get(self):
        '''Returns the item at the front of the queue, waiting for a put if the queue is empty

        Raises RuntimeError if another get is still waiting.
        '''
        if self._getter is not None:
            raise RuntimeError('Cannot get concurrently')
        if len(self._queue) == 0:
            self._getter = asyncio.get_running_loop().create_future()
            logging.debug('Awaiting getter')
            try:
                await self._getter
            except asyncio.CancelledError:
                logging.debug('Cancelled getter')
                self._getter = None
                raise
            logging.debug('Awaited getter')
            self._getter = None
        # For testing the cancellation/putting back onto the queue
        #logging.debug('Delaying message queue get')
        #await asyncio.sleep(3)
        #logging.debug('Done delaying')
        return self.get_nowait()

    def get_nowait(self):
        '''Returns the item at the front of the queue; raises asyncio.QueueEmpty if there is none'''
        if len(self._queue) == 0:
            raise asyncio.QueueEmpty
        return self._queue.popleft()

    def put_nowait(self, item):
        '''Appends item to the back of the queue and wakes up a waiting get'''
        self._queue.append(item)
        self._wake_getter()

    def putleft_nowait(self, item):
        '''Puts item at the front of the queue and wakes up a waiting get'''
        self._queue.appendleft(item)
        self._wake_getter()

    def _wake_getter(self):
        # The getter future stays assigned until the waiting get() actually resumes, so further puts (or a
        # cancellation) can find it already finished; set_result on a finished future raises InvalidStateError.
        if self._getter is not None and not self._getter.done():
            self._getter.set_result(None)

    def qsize(self):
        '''Returns the number of items currently in the queue'''
        return len(self._queue)


class IRCClientProtocol(asyncio.Protocol):
    '''Protocol for a single IRC connection

    On connect it registers with nick and real, joins the channel, and starts a task that sends the messages from
    the message queue to the channel as PRIVMSGs. Incoming PINGs are answered; every other incoming message is only
    logged. stopEvent is set when the connection is lost.
    '''

    def __init__(self, messageQueue, stopEvent, loop, nick, real, channel):
        logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}')
        self.messageQueue = messageQueue
        self.stopEvent = stopEvent
        self.loop = loop
        self.nick = nick
        self.real = real
        self.channel = channel
        self.channelb = channel.encode('utf-8')
        self.buffer = b'' # Received bytes that do not form a complete line yet
        self.connected = False
        self._sendTask = None # Task running send_messages; referenced here so it is not lost while running

    def send(self, data):
        '''Logs data and writes it to the transport followed by CRLF'''
        logging.info(f'Send: {data!r}')
        self.transport.write(data + b'\r\n')

    def connection_made(self, transport):
        logging.info('Connected')
        self.transport = transport
        self.connected = True
        nickb = self.nick.encode('utf-8')
        self.send(b'NICK ' + nickb)
        self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.real.encode('utf-8'))
        self.send(b'JOIN ' + self.channelb)
        self._sendTask = asyncio.create_task(self.send_messages())

    async def _get_message(self):
        '''Returns the next message from the queue, or None if the connection was closed'''
        logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}')
        messageFuture = asyncio.create_task(self.messageQueue.get())
        # asyncio.wait does not accept bare coroutines anymore (Python 3.11+). Wrapping the event wait in an explicit
        # task also allows cancelling it afterwards instead of leaving one pending task behind per message.
        stopFuture = asyncio.create_task(self.stopEvent.wait())
        try:
            done, pending = await asyncio.wait((messageFuture, stopFuture), return_when = asyncio.FIRST_COMPLETED)
        finally:
            stopFuture.cancel()
        if self.stopEvent.is_set():
            if messageFuture in pending:
                logging.debug('Cancelling messageFuture')
                messageFuture.cancel()
                try:
                    await messageFuture
                except asyncio.CancelledError:
                    logging.debug('Cancelled messageFuture')
            else:
                # messageFuture is already done but we're stopping, so put the result back onto the queue
                self.messageQueue.putleft_nowait(messageFuture.result())
            return None
        assert messageFuture in done, 'Invalid state: messageFuture not in done futures'
        return messageFuture.result()

    async def send_messages(self):
        '''Sends queued messages as PRIVMSGs, one per second, until the connection is closed'''
        while self.connected:
            logging.debug(f'{id(self)}: trying to get a message')
            message = await self._get_message()
            logging.debug(f'{id(self)}: got message: {message!r}')
            if message is None:
                break
            self.send(b'PRIVMSG ' + self.channelb + b' :' + message.encode('utf-8'))
            #TODO self.messageQueue.putleft_nowait if delivery fails
            await asyncio.sleep(1) # Rate limit

    def data_received(self, data):
        logging.debug(f'Data received: {data!r}')
        # Prepend whatever is left in the buffer and split on CRLF. Everything except the last piece is a complete
        # message; the last piece (empty if data ended with CRLF) is incomplete and goes back into the buffer.
        # Joining before splitting also handles chunks without any CRLF and a CRLF that is split across two chunks.
        *messages, self.buffer = (self.buffer + data).split(b'\r\n')
        for message in messages:
            self.message_received(message)

    def message_received(self, message):
        '''Handles one complete line from the server: logs it and answers PINGs'''
        logging.info(f'Message received: {message!r}')
        if message.startswith(b'PING '):
            self.send(b'PONG ' + message[5:])

    def connection_lost(self, exc):
        logging.info('The server closed the connection')
        self.connected = False
        self.stopEvent.set()


class WebServer:
    '''HTTP server that accepts nodeping notifications as JSON POSTs on /nodeping and queues their messages'''

    def __init__(self, messageQueue, host, port, auth):
        self.messageQueue = messageQueue
        self.host = host
        self.port = port
        self.auth = auth # 'user:pass' for basic auth, or a falsy value to disable the check
        if auth:
            self.authHeader = self._basic_auth_header(auth)
        self._app = aiohttp.web.Application()
        self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)])

    @staticmethod
    def _basic_auth_header(auth):
        '''Returns the Authorization header value expected for the basic auth data `auth`'''
        import base64 # Imported here because the module's import block does not provide base64
        return f'Basic {base64.b64encode(auth.encode("utf-8")).decode("utf-8")}'

    @staticmethod
    def _has_valid_message(data):
        '''Whether the decoded JSON body is an object whose message is a string without linebreaks'''
        # Valid JSON is not necessarily an object, and the message not necessarily a string. Anything else must be
        # rejected here because the IRC side encodes the message as a string.
        if not isinstance(data, dict):
            return False
        message = data.get('message')
        return isinstance(message, str) and '\r' not in message and '\n' not in message

    async def run(self, stopEvent):
        '''Serves on host:port until stopEvent is set'''
        runner = aiohttp.web.AppRunner(self._app)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, self.host, self.port)
        await site.start()
        await stopEvent.wait()
        await runner.cleanup()

    async def nodeping_post(self, request):
        '''Handles a POST: 403 for failed authentication, 400 for unusable bodies, else queues the message'''
        logging.info(f'Received request with data: {await request.read()!r}')
        authHeader = request.headers.get('Authorization')
        if self.auth and (not authHeader or authHeader != self.authHeader):
            return aiohttp.web.HTTPForbidden()
        try:
            data = await request.json()
        except (aiohttp.ContentTypeError, json.JSONDecodeError) as e:
            logging.error(f'Received invalid data: {await request.read()!r}')
            return aiohttp.web.HTTPBadRequest()
        if not self._has_valid_message(data):
            logging.error(f'Received invalid data: {await request.read()!r}')
            return aiohttp.web.HTTPBadRequest()
        logging.debug(f'Putting to message queue {id(self.messageQueue)}')
        self.messageQueue.put_nowait(data['message'])
        return aiohttp.web.HTTPOk()


async def run_irc(loop, messageQueue, sigintEvent, host, port, ssl, nick, real, channel):
    '''Keeps an IRC connection to host:port alive until sigintEvent is set

    Each connection uses a new IRCClientProtocol and is held until it closes or sigintEvent is set. A failed
    connection attempt is logged and followed by a wait of up to 5 seconds before the next attempt.
    '''
    async def wait_first(*coroutines):
        # Runs the coroutines until the first one finishes and cancels the others.
        # asyncio.wait does not accept bare coroutines anymore (Python 3.11+), hence the explicit tasks.
        tasks = [asyncio.ensure_future(coroutine) for coroutine in coroutines]
        try:
            await asyncio.wait(tasks, return_when = asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel() # No effect on tasks that are already done

    stopEvent = asyncio.Event()
    while True:
        stopEvent.clear()
        try:
            transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop, nick = nick, real = real, channel = channel), host, port, ssl = ssl)
            try:
                await wait_first(stopEvent.wait(), sigintEvent.wait())
            finally:
                transport.close()
        except (OSError, asyncio.TimeoutError) as e:
            # OSError includes ConnectionRefusedError as well as e.g. name resolution failures; all of them are worth a retry.
            logging.error(str(e))
            await wait_first(asyncio.sleep(5), sigintEvent.wait())
        if sigintEvent.is_set():
            break


async def run_webserver(loop, messageQueue, sigintEvent, host, port, auth):
    '''Creates the WebServer for host, port, and auth and runs it with sigintEvent as its stop event'''
    await WebServer(messageQueue, host, port, auth).run(sigintEvent)


def parse_args():
    '''Parses the command line options for the IRC and web sides and returns the argparse namespace'''
    # One (flag, add_argument keywords) entry per option, in the order they appear in the help output
    options = (
        ('--irchost', {'type': str, 'help': 'IRC server hostname', 'default': 'irc.hackint.org'}),
        ('--ircport', {'type': int, 'help': 'IRC server port', 'default': 6697}),
        ('--ircssl', {'choices': ['yes', 'no', 'insecure'], 'help': 'enable, disable, or use insecure SSL/TLS', 'default': 'yes'}),
        ('--ircnick', {'help': 'IRC nickname', 'default': 'npbot'}),
        ('--ircreal', {'help': 'IRC realname', 'default': 'I am a bot.'}),
        ('--ircchannel', {'help': 'IRC channel to join and post messages', 'default': '#nodeping'}),
        ('--webhost', {'type': str, 'help': 'web server host to bind to', 'default': '127.0.0.1'}),
        ('--webport', {'type': int, 'help': 'web server port to bind to', 'default': 8080}),
        ('--webauth', {'type': str, 'help': 'basic auth data (user:pass, or None to disable the check)', 'default': None}),
    )
    argumentParser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    for flag, keywords in options:
        argumentParser.add_argument(flag, **keywords)
    return argumentParser.parse_args()


async def main():
    '''Parses the command line and runs the IRC client and the web server until SIGINT is received'''
    # Imported here because the module's import block does not provide ssl. The result must not be bound to a
    # local named `ssl` either: that would make `ssl` local to the whole function and fail on first use.
    import ssl as sslModule

    args = parse_args()
    if args.ircssl == 'insecure':
        # TLS without certificate or hostname verification
        sslContext = sslModule.SSLContext(sslModule.PROTOCOL_TLS_CLIENT)
        sslContext.check_hostname = False # Has to be disabled before verify_mode may be set to CERT_NONE
        sslContext.verify_mode = sslModule.CERT_NONE
    else:
        # True enables TLS with default settings, False disables it
        sslContext = args.ircssl == 'yes'

    loop = asyncio.get_running_loop()

    messageQueue = MessageQueue()
    sigintEvent = asyncio.Event()

    def sigint_callback():
        logging.info('Got SIGINT')
        sigintEvent.set()
    loop.add_signal_handler(signal.SIGINT, sigint_callback)

    irc = run_irc(loop, messageQueue, sigintEvent, host = args.irchost, port = args.ircport, ssl = sslContext, nick = args.ircnick, real = args.ircreal, channel = args.ircchannel)
    webserver = run_webserver(loop, messageQueue, sigintEvent, host = args.webhost, port = args.webport, auth = args.webauth)
    await asyncio.gather(irc, webserver)


if __name__ == '__main__':
    asyncio.run(main())

Cargando…
Cancelar
Guardar