|
|
@@ -1,17 +1,72 @@ |
|
|
|
import aiohttp |
|
|
|
import aiohttp.web |
|
|
|
import asyncio |
|
|
|
import collections |
|
|
|
import concurrent.futures |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import signal |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level = logging.INFO, format = '{asctime} {levelname} {message}', style = '{') |
|
|
|
logging.basicConfig(level = logging.DEBUG, format = '{asctime} {levelname} {message}', style = '{') |
|
|
|
|
|
|
|
|
|
|
|
class MessageQueue: |
|
|
|
# An object holding onto the messages received from nodeping |
|
|
|
# This is effectively a reimplementation of parts of asyncio.Queue with some specific additional code. |
|
|
|
# Unfortunately, asyncio.Queue's extensibility (_init, _put, and _get methods) is undocumented, so I don't want to rely on that. |
|
|
|
# Differences to asyncio.Queue include: |
|
|
|
# - No maxsize |
|
|
|
# - No put coroutine (not necessary since the queue can never be full) |
|
|
|
# - Only one concurrent getter |
|
|
|
# - putleft_nowait to put to the front of the queue (so that the IRC client can put a message back when delivery fails) |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self._getter = None # None | asyncio.Future |
|
|
|
self._queue = collections.deque() |
|
|
|
|
|
|
|
async def get(self): |
|
|
|
if self._getter is not None: |
|
|
|
raise RuntimeError('Cannot get concurrently') |
|
|
|
if len(self._queue) == 0: |
|
|
|
self._getter = asyncio.get_running_loop().create_future() |
|
|
|
logging.debug('Awaiting getter') |
|
|
|
try: |
|
|
|
await self._getter |
|
|
|
except asyncio.CancelledError: |
|
|
|
logging.debug('Cancelled getter') |
|
|
|
self._getter = None |
|
|
|
raise |
|
|
|
logging.debug('Awaited getter') |
|
|
|
self._getter = None |
|
|
|
# For testing the cancellation/putting back onto the queue |
|
|
|
#logging.debug('Delaying message queue get') |
|
|
|
#await asyncio.sleep(3) |
|
|
|
#logging.debug('Done delaying') |
|
|
|
return self.get_nowait() |
|
|
|
|
|
|
|
def get_nowait(self): |
|
|
|
if len(self._queue) == 0: |
|
|
|
raise asyncio.QueueEmpty |
|
|
|
return self._queue.popleft() |
|
|
|
|
|
|
|
def put_nowait(self, item): |
|
|
|
self._queue.append(item) |
|
|
|
if self._getter is not None: |
|
|
|
self._getter.set_result(None) |
|
|
|
|
|
|
|
def putleft_nowait(self, item): |
|
|
|
self._queue.appendleft(item) |
|
|
|
if self._getter is not None: |
|
|
|
self._getter.set_result(None) |
|
|
|
|
|
|
|
def qsize(self): |
|
|
|
return len(self._queue) |
|
|
|
|
|
|
|
|
|
|
|
class IRCClientProtocol(asyncio.Protocol): |
|
|
|
def __init__(self, messageQueue, stopEvent, loop): |
|
|
|
logging.debug(f'Protocol init {id(self)}: {messageQueue} {id(messageQueue)}, {stopEvent}, {loop}') |
|
|
|
self.messageQueue = messageQueue |
|
|
|
self.stopEvent = stopEvent |
|
|
|
self.loop = loop |
|
|
@@ -31,10 +86,35 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
self.send(b'JOIN #nodeping') |
|
|
|
asyncio.create_task(self.send_messages()) |
|
|
|
|
|
|
|
async def _get_message(self): |
|
|
|
logging.debug(f'Message queue {id(self.messageQueue)} length: {self.messageQueue.qsize()}') |
|
|
|
messageFuture = asyncio.create_task(self.messageQueue.get()) |
|
|
|
done, pending = await asyncio.wait((messageFuture, self.stopEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) |
|
|
|
if self.stopEvent.is_set(): |
|
|
|
if messageFuture in pending: |
|
|
|
logging.debug('Cancelling messageFuture') |
|
|
|
messageFuture.cancel() |
|
|
|
try: |
|
|
|
await messageFuture |
|
|
|
except asyncio.CancelledError: |
|
|
|
logging.debug('Cancelled messageFuture') |
|
|
|
pass |
|
|
|
else: |
|
|
|
# messageFuture is already done but we're stopping, so put the result back onto the queue |
|
|
|
self.messageQueue.putleft_nowait(messageFuture.result()) |
|
|
|
return None |
|
|
|
assert messageFuture in done, 'Invalid state: messageFuture not in done futures' |
|
|
|
return messageFuture.result() |
|
|
|
|
|
|
|
async def send_messages(self): |
|
|
|
while self.connected: |
|
|
|
message = await self.messageQueue.get() |
|
|
|
logging.debug(f'{id(self)}: trying to get a message') |
|
|
|
message = await self._get_message() |
|
|
|
logging.debug(f'{id(self)}: got message: {message!r}') |
|
|
|
if message is None: |
|
|
|
break |
|
|
|
self.send(b'PRIVMSG #nodeping :' + message.encode('utf-8')) |
|
|
|
#TODO self.messageQueue.putleft_nowait if delivery fails |
|
|
|
await asyncio.sleep(1) # Rate limit |
|
|
|
|
|
|
|
def data_received(self, data): |
|
|
@@ -57,6 +137,7 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
|
|
|
|
def connection_lost(self, exc): |
|
|
|
logging.info('The server closed the connection') |
|
|
|
self.connected = False |
|
|
|
self.stopEvent.set() |
|
|
|
|
|
|
|
|
|
|
@@ -90,6 +171,7 @@ class WebServer: |
|
|
|
if '\r' in data['message'] or '\n' in data['message']: |
|
|
|
logging.error(f'Received invalid data: {await request.read()!r}') |
|
|
|
return aiohttp.web.HTTPBadRequest() |
|
|
|
logging.debug(f'Putting to message queue {id(self.messageQueue)}') |
|
|
|
self.messageQueue.put_nowait(data['message']) |
|
|
|
return aiohttp.web.HTTPOk() |
|
|
|
|
|
|
@@ -120,7 +202,7 @@ async def run_webserver(loop, messageQueue, sigintEvent): |
|
|
|
async def main(): |
|
|
|
loop = asyncio.get_running_loop() |
|
|
|
|
|
|
|
messageQueue = asyncio.Queue() |
|
|
|
messageQueue = MessageQueue() |
|
|
|
sigintEvent = asyncio.Event() |
|
|
|
|
|
|
|
def sigint_callback(): |
|
|
|