Browse Source

Get rid of the SimpleNamespace for configuration since it complicates config change detection

SimpleNamespace's documentation does not say anything about equality tests, which is why the update_config methods checked the relevant values using a tuple instead. But as more config values are added, this makes the comparisons unnecessarily long. A dict simplifies this. As a side-effect, the constraints of the maps keys being a valid identifier are no longer relevant either.
master
JustAnotherArchivist 4 years ago
parent
commit
859146621a
1 changed files with 20 additions and 31 deletions
  1. +20
    -31
      http2irc.py

+ 20
- 31
http2irc.py View File

@@ -10,7 +10,6 @@ import signal
import ssl import ssl
import sys import sys
import toml import toml
import types




logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{') logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')
@@ -23,11 +22,6 @@ class InvalidConfig(Exception):
'''Error in configuration file''' '''Error in configuration file'''




def _mapping_to_namespace(d):
'''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively'''
return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()})


def is_valid_pem(path, withCert): def is_valid_pem(path, withCert):
'''Very basic check whether something looks like a valid PEM certificate''' '''Very basic check whether something looks like a valid PEM certificate'''
try: try:
@@ -51,13 +45,10 @@ def is_valid_pem(path, withCert):
return False return False




class Config:
class Config(dict):
def __init__(self, filename): def __init__(self, filename):
super().__init__()
self._filename = filename self._filename = filename
# Set below:
self.irc = None
self.web = None
self.maps = None


with open(self._filename, 'r') as fp: with open(self._filename, 'r') as fp:
obj = toml.load(fp) obj = toml.load(fp)
@@ -107,9 +98,7 @@ class Config:
raise InvalidConfig('Invalid web port') raise InvalidConfig('Invalid web port')
if 'maps' in obj: if 'maps' in obj:
for key, map_ in obj['maps'].items(): for key, map_ in obj['maps'].items():
# Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace.
#TODO: Support for fancier identifiers (PEP 3131)?
if not isinstance(key, str) or not key or key.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_') != '' or key[0].strip('0123456789') == '':
if not isinstance(key, str) or not key:
raise InvalidConfig(f'Invalid map key {key!r}') raise InvalidConfig(f'Invalid map key {key!r}')
if not isinstance(map_, collections.abc.Mapping): if not isinstance(map_, collections.abc.Mapping):
raise InvalidConfig(f'Invalid map for {key!r}') raise InvalidConfig(f'Invalid map for {key!r}')
@@ -118,7 +107,7 @@ class Config:
#TODO: Check values #TODO: Check values


# Default values # Default values
self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
finalObj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}


# Fill in default values for the maps # Fill in default values for the maps
for key, map_ in obj['maps'].items(): for key, map_ in obj['maps'].items():
@@ -129,14 +118,14 @@ class Config:
if 'auth' not in map_: if 'auth' not in map_:
map_['auth'] = False map_['auth'] = False


# Merge in what was read from the config file and convert to SimpleNamespace
# Merge in what was read from the config file and set keys on self
for key in ('irc', 'web', 'maps'): for key in ('irc', 'web', 'maps'):
if key in obj: if key in obj:
self._obj[key].update(obj[key])
setattr(self, key, _mapping_to_namespace(self._obj[key]))
finalObj[key].update(obj[key])
self[key] = finalObj[key]


def __repr__(self): def __repr__(self):
return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})'
return f'<Config(irc={self["irc"]!r}, web={self["web"]!r}, maps={self["maps"]!r})>'


def reread(self): def reread(self):
return Config(self._filename) return Config(self._filename)
@@ -212,9 +201,9 @@ class IRCClientProtocol(asyncio.Protocol):
logging.info('Connected') logging.info('Connected')
self.transport = transport self.transport = transport
self.connected = True self.connected = True
nickb = self.config.irc.nick.encode('utf-8')
nickb = self.config['irc']['nick'].encode('utf-8')
self.send(b'NICK ' + nickb) self.send(b'NICK ' + nickb)
self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8'))
self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config['irc']['real'].encode('utf-8'))


def update_channels(self, channels: set): def update_channels(self, channels: set):
channelsToPart = self.channels - channels channelsToPart = self.channels - channels
@@ -323,28 +312,28 @@ class IRCClient:
def __init__(self, messageQueue, config): def __init__(self, messageQueue, config):
self.messageQueue = messageQueue self.messageQueue = messageQueue
self.config = config self.config = config
self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}


self._transport = None self._transport = None
self._protocol = None self._protocol = None


def update_config(self, config): def update_config(self, config):
needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl)
needReconnect = self.config['irc'] != config['irc']
self.config = config self.config = config
if self._transport: # if currently connected: if self._transport: # if currently connected:
if needReconnect: if needReconnect:
self._transport.close() self._transport.close()
else: else:
self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
self._protocol.update_channels(self.channels) self._protocol.update_channels(self.channels)


def _get_ssl_context(self): def _get_ssl_context(self):
ctx = SSL_CONTEXTS[self.config.irc.ssl]
if self.config.irc.certfile and self.config.irc.certkeyfile:
ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
if ctx is True: if ctx is True:
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
if isinstance(ctx, ssl.SSLContext): if isinstance(ctx, ssl.SSLContext):
ctx.load_cert_chain(self.config.irc.certfile, keyfile = self.config.irc.certkeyfile)
ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
return ctx return ctx


async def run(self, loop, sigintEvent): async def run(self, loop, sigintEvent):
@@ -352,7 +341,7 @@ class IRCClient:
while True: while True:
connectionClosedEvent.clear() connectionClosedEvent.clear()
try: try:
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config.irc.host, self.config.irc.port, ssl = self._get_ssl_context())
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
try: try:
await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
finally: finally:
@@ -377,8 +366,8 @@ class WebServer:
self.update_config(config) self.update_config(config)


def update_config(self, config): def update_config(self, config):
self._paths = {map_.webpath: (map_.ircchannel, f'Basic {base64.b64encode(map_.auth.encode("utf-8")).decode("utf-8")}' if map_.auth else False) for map_ in config.maps.__dict__.values()}
needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port)
self._paths = {map_['webpath']: (map_['ircchannel'], f'Basic {base64.b64encode(map_["auth"].encode("utf-8")).decode("utf-8")}' if map_['auth'] else False) for map_ in config['maps'].values()}
needRebind = self.config['web'] != config['web']
self.config = config self.config = config
if needRebind: if needRebind:
#TODO #TODO
@@ -387,7 +376,7 @@ class WebServer:
async def run(self, stopEvent): async def run(self, stopEvent):
runner = aiohttp.web.AppRunner(self._app) runner = aiohttp.web.AppRunner(self._app)
await runner.setup() await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port)
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start() await site.start()
await stopEvent.wait() await stopEvent.wait()
await runner.cleanup() await runner.cleanup()


Loading…
Cancel
Save