|
|
@@ -0,0 +1,135 @@ |
|
|
|
import aiohttp |
|
|
|
import aiohttp.web |
|
|
|
import asyncio |
|
|
|
import concurrent.futures |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import signal |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level = logging.INFO, format = '{asctime} {levelname} {message}', style = '{') |
|
|
|
|
|
|
|
|
|
|
|
class IRCClientProtocol(asyncio.Protocol): |
|
|
|
def __init__(self, messageQueue, stopEvent, loop): |
|
|
|
self.messageQueue = messageQueue |
|
|
|
self.stopEvent = stopEvent |
|
|
|
self.loop = loop |
|
|
|
self.buffer = b'' |
|
|
|
self.connected = False |
|
|
|
|
|
|
|
def send(self, data): |
|
|
|
logging.info(f'Send: {data!r}') |
|
|
|
self.transport.write(data + b'\r\n') |
|
|
|
|
|
|
|
def connection_made(self, transport): |
|
|
|
logging.info('Connected') |
|
|
|
self.transport = transport |
|
|
|
self.connected = True |
|
|
|
self.send(b'NICK npbot') |
|
|
|
self.send(b'USER npbot npbot npbot :I am a bot.') |
|
|
|
self.send(b'JOIN #nodeping') |
|
|
|
asyncio.create_task(self.send_messages()) |
|
|
|
|
|
|
|
async def send_messages(self): |
|
|
|
while self.connected: |
|
|
|
message = await self.messageQueue.get() |
|
|
|
self.send(b'PRIVMSG #nodeping :' + message.encode('utf-8')) |
|
|
|
await asyncio.sleep(1) # Rate limit |
|
|
|
|
|
|
|
def data_received(self, data): |
|
|
|
logging.debug(f'Data received: {data!r}') |
|
|
|
# Split received data on CRLF. If there's any data left in the buffer, prepend it to the first message and process that. |
|
|
|
# Then, process all messages except the last one (since data might not end on a CRLF) and keep the remainder in the buffer. |
|
|
|
# If data does end with CRLF, all messages will have been processed and the buffer will be empty again. |
|
|
|
messages = data.split(b'\r\n') |
|
|
|
if self.buffer: |
|
|
|
self.message_received(self.buffer + messages[0]) |
|
|
|
messages = messages[1:] |
|
|
|
for message in messages[:-1]: |
|
|
|
self.message_received(message) |
|
|
|
self.buffer = messages[-1] |
|
|
|
|
|
|
|
def message_received(self, message): |
|
|
|
logging.info(f'Message received: {message!r}') |
|
|
|
if message.startswith(b'PING '): |
|
|
|
self.send(b'PONG ' + message[5:]) |
|
|
|
|
|
|
|
def connection_lost(self, exc): |
|
|
|
logging.info('The server closed the connection') |
|
|
|
self.stopEvent.set() |
|
|
|
|
|
|
|
|
|
|
|
class WebServer: |
|
|
|
def __init__(self, messageQueue): |
|
|
|
self.messageQueue = messageQueue |
|
|
|
self._app = aiohttp.web.Application() |
|
|
|
self._app.add_routes([aiohttp.web.post('/nodeping', self.nodeping_post)]) |
|
|
|
|
|
|
|
async def run(self, stopEvent): |
|
|
|
runner = aiohttp.web.AppRunner(self._app) |
|
|
|
await runner.setup() |
|
|
|
site = aiohttp.web.TCPSite(runner, '127.0.0.1', 8080) |
|
|
|
await site.start() |
|
|
|
await stopEvent.wait() |
|
|
|
await runner.cleanup() |
|
|
|
|
|
|
|
async def nodeping_post(self, request): |
|
|
|
logging.info(f'Received request with data: {await request.read()!r}') |
|
|
|
authHeader = request.headers.get('Authorization') |
|
|
|
if not authHeader or authHeader != 'Basic YXJjaGl2ZXRlYW06aXNvbmZpcmU=': #TODO move out of source code |
|
|
|
return aiohttp.web.HTTPForbidden() |
|
|
|
try: |
|
|
|
data = await request.json() |
|
|
|
except (aiohttp.ContentTypeError, json.JSONDecodeError) as e: |
|
|
|
logging.error(f'Received invalid data: {await request.read()!r}') |
|
|
|
return aiohttp.web.HTTPBadRequest() |
|
|
|
if 'message' not in data: |
|
|
|
logging.error(f'Received invalid data: {await request.read()!r}') |
|
|
|
return aiohttp.web.HTTPBadRequest() |
|
|
|
if '\r' in data['message'] or '\n' in data['message']: |
|
|
|
logging.error(f'Received invalid data: {await request.read()!r}') |
|
|
|
return aiohttp.web.HTTPBadRequest() |
|
|
|
self.messageQueue.put_nowait(data['message']) |
|
|
|
return aiohttp.web.HTTPOk() |
|
|
|
|
|
|
|
|
|
|
|
async def run_irc(loop, messageQueue, sigintEvent): |
|
|
|
stopEvent = asyncio.Event() |
|
|
|
while True: |
|
|
|
stopEvent.clear() |
|
|
|
try: |
|
|
|
# transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop), 'irc.hackint.org', 6697, ssl = True) |
|
|
|
transport, protocol = await loop.create_connection(lambda: IRCClientProtocol(messageQueue, stopEvent, loop), '127.0.0.1', 8888) |
|
|
|
try: |
|
|
|
await asyncio.wait((stopEvent.wait(), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) |
|
|
|
finally: |
|
|
|
transport.close() |
|
|
|
except (ConnectionRefusedError, asyncio.TimeoutError) as e: |
|
|
|
logging.error(str(e)) |
|
|
|
await asyncio.wait((asyncio.sleep(5), sigintEvent.wait()), return_when = concurrent.futures.FIRST_COMPLETED) |
|
|
|
if sigintEvent.is_set(): |
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
async def run_webserver(loop, messageQueue, sigintEvent): |
|
|
|
server = WebServer(messageQueue) |
|
|
|
await server.run(sigintEvent) |
|
|
|
|
|
|
|
|
|
|
|
async def main(): |
|
|
|
loop = asyncio.get_running_loop() |
|
|
|
|
|
|
|
messageQueue = asyncio.Queue() |
|
|
|
sigintEvent = asyncio.Event() |
|
|
|
|
|
|
|
def sigint_callback(): |
|
|
|
logging.info('Got SIGINT') |
|
|
|
nonlocal sigintEvent |
|
|
|
sigintEvent.set() |
|
|
|
loop.add_signal_handler(signal.SIGINT, sigint_callback) |
|
|
|
|
|
|
|
await asyncio.gather(run_irc(loop, messageQueue, sigintEvent), run_webserver(loop, messageQueue, sigintEvent)) |
|
|
|
|
|
|
|
|
|
|
|
asyncio.run(main()) |