From 50a8b79839cef64ad4293043ff99814d14562b86 Mon Sep 17 00:00:00 2001 From: JustAnotherArchivist Date: Tue, 9 Feb 2021 07:09:47 +0000 Subject: [PATCH] Fix memory leak due to asyncio tasks not getting cancelled asyncio.wait doesn't cancel tasks on reaching the timeout, so all those Event.wait() tasks kept accumulating. --- irclog.py | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/irclog.py b/irclog.py index dffaff4..1f483aa 100644 --- a/irclog.py +++ b/irclog.py @@ -68,6 +68,22 @@ def is_valid_pem(path, withCert): return False +async def wait_cancel_pending(aws, paws = None, **kwargs): + '''asyncio.wait but with automatic cancellation of non-completed tasks. Tasks in paws (persistent awaitables) are not automatically cancelled.''' + if paws is None: + paws = set() + tasks = aws | paws + done, pending = await asyncio.wait(tasks, **kwargs) + for task in pending: + if task not in paws: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + return done, pending + + class Config(dict): def __init__(self, filename): super().__init__() @@ -398,7 +414,7 @@ class IRCClientProtocol(asyncio.Protocol): while True: self.logger.debug(f'Trying to get data from send queue') t = asyncio.create_task(self.sendQueue.get()) - done, pending = await asyncio.wait({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) + done, pending = await wait_cancel_pending({t, asyncio.create_task(self.connectionClosedEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) if self.connectionClosedEvent.is_set(): break assert t in done, f'{t!r} is not in {done!r}' @@ -407,7 +423,7 @@ class IRCClientProtocol(asyncio.Protocol): now = time.time() if self.lastSentTime is not None and now - self.lastSentTime < 1: self.logger.debug(f'Rate limited') - await asyncio.wait({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now) + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = self.lastSentTime + 1 - now) if self.connectionClosedEvent.is_set(): break time_ = self._direct_send(data) @@ -635,7 +651,7 @@ class IRCClientProtocol(asyncio.Protocol): self.logger.info('Quitting') self.lastSentTime = 1.67e34 * math.pi * 1e7 # Disable sending any further messages in send_queue self._direct_send(b'QUIT :Bye') - await asyncio.wait({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10) + await wait_cancel_pending({asyncio.create_task(self.connectionClosedEvent.wait())}, timeout = 10) if not self.connectionClosedEvent.is_set(): self.logger.error('Quitting cleanly did not work, closing connection forcefully') # Event will be set implicitly in connection_lost. @@ -693,26 +709,33 @@ class IRCClient: port = self.config['irc']['port'], ssl = self._get_ssl_context(), )) - done, _ = await asyncio.wait({t, asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED, timeout = 30) + # No automatic cancellation of t because it's handled manually below. + done, _ = await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, paws = {t}, return_when = asyncio.FIRST_COMPLETED, timeout = 30) if t not in done: t.cancel() await t # Raises the CancelledError self._transport, self._protocol = t.result() self.logger.debug('Starting send queue processing') - asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent + sendTask = asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent self.logger.debug('Waiting for connection closure or SIGINT') try: - await asyncio.wait({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(connectionClosedEvent.wait()), asyncio.create_task(sigintEvent.wait())}, return_when = asyncio.FIRST_COMPLETED) finally: self.logger.debug(f'Got connection closed {connectionClosedEvent.is_set()} / SIGINT {sigintEvent.is_set()}') if not connectionClosedEvent.is_set(): self.logger.debug('Quitting connection') await self._protocol.quit() + if not sendTask.done(): + sendTask.cancel() + try: + await sendTask + except asyncio.CancelledError: + pass self._transport = None self._protocol = None except (ConnectionRefusedError, ssl.SSLError, asyncio.TimeoutError, asyncio.CancelledError) as e: self.logger.error(f'{type(e).__module__}.{type(e).__name__}: {e!s}') - await asyncio.wait({asyncio.create_task(sigintEvent.wait())}, timeout = 5) + await wait_cancel_pending({asyncio.create_task(sigintEvent.wait())}, timeout = 5) if sigintEvent.is_set(): self.logger.debug('Got SIGINT, putting EOF and breaking') self.messageQueue.put_nowait(messageEOF) @@ -807,7 +830,7 @@ class Storage: async def flush_files(self, flushExitEvent): lastFlushTime = 0 while True: - await asyncio.wait({asyncio.create_task(flushExitEvent.wait())}, timeout = self.config['storage']['flushTime']) + await wait_cancel_pending({asyncio.create_task(flushExitEvent.wait())}, timeout = self.config['storage']['flushTime']) self.logger.debug('Flushing files') flushedFiles = [] for channel, (fn, f, fLastWriteTime) in self.files.items(): @@ -883,7 +906,7 @@ class WebServer: await runner.setup() site = aiohttp.web.TCPSite(runner, self.config['web']['host'], self.config['web']['port']) await site.start() - await asyncio.wait({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED) + await wait_cancel_pending({asyncio.create_task(stopEvent.wait()), asyncio.create_task(self._configChanged.wait())}, return_when = asyncio.FIRST_COMPLETED) await runner.cleanup() if stopEvent.is_set(): break @@ -1140,17 +1163,20 @@ class WebServer: stderrTask = asyncio.create_task(process_stderr()) await asyncio.wait({stdoutTask, stderrTask}, timeout = self.config['web']['search']['maxTime'] if self.config['web']['search']['maxTime'] != 0 else None) # The stream readers may quit before the process is done even on a successful grep. Wait a tiny bit longer for the process to exit. - await asyncio.wait({asyncio.create_task(proc.wait())}, timeout = 0.1) + procTask = asyncio.create_task(proc.wait()) + await asyncio.wait({procTask}, timeout = 0.1) if proc.returncode is None: # Process hasn't finished yet after maxTime. Murder it and wait for it to die. + assert not procTask.done(), 'procTask is done but proc.returncode is None' self.logger.warning(f'Request {id(request)} grep took more than the time limit') proc.kill() - await asyncio.wait({stdoutTask, stderrTask, asyncio.create_task(proc.wait())}, timeout = 1) # This really shouldn't take longer. + await asyncio.wait({stdoutTask, stderrTask, procTask}, timeout = 1) # This really shouldn't take longer. if proc.returncode is None: # Still not done?! Cancel tasks and bail. self.logger.error(f'Request {id(request)} grep did not exit after getting killed!') stdoutTask.cancel() stderrTask.cancel() + procTask.cancel() return aiohttp.web.HTTPInternalServerError() stdout, incomplete = stdoutTask.result() self.logger.info(f'Request {id(request)} grep exited with {proc.returncode} and produced {len(stdout)} bytes (incomplete: {incomplete})')