Просмотр исходного кода

Get rid of the SimpleNamespace for configuration since it complicates config change detection

SimpleNamespace's documentation does not say anything about equality tests, which is why the update_config methods checked the relevant values using a tuple instead. But as more config values are added, this makes the comparisons unnecessarily long. A dict simplifies this. As a side-effect, the constraints of the maps keys being a valid identifier are no longer relevant either.
master
JustAnotherArchivist 4 лет назад
Родитель
Сommit
859146621a
1 измененных файлов: 20 добавлений и 31 удалений
  1. +20
    -31
      http2irc.py

+ 20
- 31
http2irc.py Просмотреть файл

@@ -10,7 +10,6 @@ import signal
import ssl
import sys
import toml
import types


logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{')
@@ -23,11 +22,6 @@ class InvalidConfig(Exception):
'''Error in configuration file'''


def _mapping_to_namespace(d):
'''Converts a mapping (e.g. dict) to a types.SimpleNamespace, recursively'''
return types.SimpleNamespace(**{key: _mapping_to_namespace(value) if isinstance(value, collections.abc.Mapping) else value for key, value in d.items()})


def is_valid_pem(path, withCert):
'''Very basic check whether something looks like a valid PEM certificate'''
try:
@@ -51,13 +45,10 @@ def is_valid_pem(path, withCert):
return False


class Config:
class Config(dict):
def __init__(self, filename):
super().__init__()
self._filename = filename
# Set below:
self.irc = None
self.web = None
self.maps = None

with open(self._filename, 'r') as fp:
obj = toml.load(fp)
@@ -107,9 +98,7 @@ class Config:
raise InvalidConfig('Invalid web port')
if 'maps' in obj:
for key, map_ in obj['maps'].items():
# Ensure that the key is a valid Python identifier since it will be set as an attribute in the namespace.
#TODO: Support for fancier identifiers (PEP 3131)?
if not isinstance(key, str) or not key or key.strip('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_') != '' or key[0].strip('0123456789') == '':
if not isinstance(key, str) or not key:
raise InvalidConfig(f'Invalid map key {key!r}')
if not isinstance(map_, collections.abc.Mapping):
raise InvalidConfig(f'Invalid map for {key!r}')
@@ -118,7 +107,7 @@ class Config:
#TODO: Check values

# Default values
self._obj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}
finalObj = {'irc': {'host': 'irc.hackint.org', 'port': 6697, 'ssl': 'yes', 'nick': 'h2ibot', 'real': 'I am an http2irc bot.', 'certfile': None, 'certkeyfile': None}, 'web': {'host': '127.0.0.1', 'port': 8080}, 'maps': {}}

# Fill in default values for the maps
for key, map_ in obj['maps'].items():
@@ -129,14 +118,14 @@ class Config:
if 'auth' not in map_:
map_['auth'] = False

# Merge in what was read from the config file and convert to SimpleNamespace
# Merge in what was read from the config file and set keys on self
for key in ('irc', 'web', 'maps'):
if key in obj:
self._obj[key].update(obj[key])
setattr(self, key, _mapping_to_namespace(self._obj[key]))
finalObj[key].update(obj[key])
self[key] = finalObj[key]

def __repr__(self):
return f'Config(irc={self.irc!r}, web={self.web!r}, maps={self.maps!r})'
return f'<Config(irc={self["irc"]!r}, web={self["web"]!r}, maps={self["maps"]!r})>'

def reread(self):
return Config(self._filename)
@@ -212,9 +201,9 @@ class IRCClientProtocol(asyncio.Protocol):
logging.info('Connected')
self.transport = transport
self.connected = True
nickb = self.config.irc.nick.encode('utf-8')
nickb = self.config['irc']['nick'].encode('utf-8')
self.send(b'NICK ' + nickb)
self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config.irc.real.encode('utf-8'))
self.send(b'USER ' + nickb + b' ' + nickb + b' ' + nickb + b' :' + self.config['irc']['real'].encode('utf-8'))

def update_channels(self, channels: set):
channelsToPart = self.channels - channels
@@ -323,28 +312,28 @@ class IRCClient:
def __init__(self, messageQueue, config):
self.messageQueue = messageQueue
self.config = config
self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}

self._transport = None
self._protocol = None

def update_config(self, config):
needReconnect = (self.config.irc.host, self.config.irc.port, self.config.irc.ssl) != (config.irc.host, config.irc.port, config.irc.ssl)
needReconnect = self.config['irc'] != config['irc']
self.config = config
if self._transport: # if currently connected:
if needReconnect:
self._transport.close()
else:
self.channels = {map_.ircchannel for map_ in config.maps.__dict__.values()}
self.channels = {map_['ircchannel'] for map_ in config['maps'].values()}
self._protocol.update_channels(self.channels)

def _get_ssl_context(self):
ctx = SSL_CONTEXTS[self.config.irc.ssl]
if self.config.irc.certfile and self.config.irc.certkeyfile:
ctx = SSL_CONTEXTS[self.config['irc']['ssl']]
if self.config['irc']['certfile'] and self.config['irc']['certkeyfile']:
if ctx is True:
ctx = ssl.create_default_context()
if isinstance(ctx, ssl.SSLContext):
ctx.load_cert_chain(self.config.irc.certfile, keyfile = self.config.irc.certkeyfile)
ctx.load_cert_chain(self.config['irc']['certfile'], keyfile = self.config['irc']['certkeyfile'])
return ctx

async def run(self, loop, sigintEvent):
@@ -352,7 +341,7 @@ class IRCClient:
while True:
connectionClosedEvent.clear()
try:
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config.irc.host, self.config.irc.port, ssl = self._get_ssl_context())
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context())
try:
await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED)
finally:
@@ -377,8 +366,8 @@ class WebServer:
self.update_config(config)

def update_config(self, config):
self._paths = {map_.webpath: (map_.ircchannel, f'Basic {base64.b64encode(map_.auth.encode("utf-8")).decode("utf-8")}' if map_.auth else False) for map_ in config.maps.__dict__.values()}
needRebind = (self.config.web.host, self.config.web.port) != (config.web.host, config.web.port)
self._paths = {map_['webpath']: (map_['ircchannel'], f'Basic {base64.b64encode(map_["auth"].encode("utf-8")).decode("utf-8")}' if map_['auth'] else False) for map_ in config['maps'].values()}
needRebind = self.config['web'] != config['web']
self.config = config
if needRebind:
#TODO
@@ -387,7 +376,7 @@ class WebServer:
async def run(self, stopEvent):
runner = aiohttp.web.AppRunner(self._app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, self.config.web.host, self.config.web.port)
site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port'])
await site.start()
await stopEvent.wait()
await runner.cleanup()


Загрузка…
Отмена
Сохранить