|
|
@@ -217,6 +217,8 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
self.connectionClosedEvent = connectionClosedEvent |
|
|
|
self.loop = loop |
|
|
|
self.config = config |
|
|
|
self.lastSentTime = None # float timestamp or None; the latter disables the send rate limit |
|
|
|
self.sendQueue = asyncio.Queue() |
|
|
|
self.buffer = b'' |
|
|
|
self.connected = False |
|
|
|
self.channels = channels # Currently joined/supposed-to-be-joined channels; set(str) |
|
|
@@ -316,12 +318,33 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
self._send_join_part(b'JOIN', channelsToJoin) |
|
|
|
|
|
|
|
def send(self, data): |
|
|
|
self.logger.debug(f'Send: {data!r}') |
|
|
|
self.logger.debug(f'Queueing for send: {data!r}') |
|
|
|
if len(data) > 510: |
|
|
|
raise RuntimeError(f'IRC message too long ({len(data)} > 510): {data!r}') |
|
|
|
time_ = time.time() |
|
|
|
self.transport.write(data + b'\r\n') |
|
|
|
self.messageQueue.put_nowait((time_, b'> ' + data, None, None)) |
|
|
|
self.sendQueue.put_nowait(data) |
|
|
|
|
|
|
|
async def send_queue(self): |
|
|
|
while True: |
|
|
|
self.logger.debug(f'Trying to get data from send queue') |
|
|
|
t = asyncio.create_task(self.sendQueue.get()) |
|
|
|
done, pending = await asyncio.wait((t, self.connectionClosedEvent.wait()), return_when = asyncio.FIRST_COMPLETED) |
|
|
|
if self.connectionClosedEvent.is_set(): |
|
|
|
break |
|
|
|
assert t in done, f'{t!r} is not in {done!r}' |
|
|
|
data = t.result() |
|
|
|
self.logger.debug(f'Got {data!r} from send queue') |
|
|
|
now = time.time() |
|
|
|
if self.lastSentTime is not None and now - self.lastSentTime < 1: |
|
|
|
self.logger.debug(f'Rate limited') |
|
|
|
await asyncio.wait((asyncio.sleep(self.lastSentTime + 1 - now), self.connectionClosedEvent.wait()), return_when = asyncio.FIRST_COMPLETED) |
|
|
|
if self.connectionClosedEvent.is_set(): |
|
|
|
break |
|
|
|
self.logger.debug(f'Send: {data!r}') |
|
|
|
time_ = time.time() |
|
|
|
self.transport.write(data + b'\r\n') |
|
|
|
self.messageQueue.put_nowait((time_, b'> ' + data, None, None)) |
|
|
|
if self.lastSentTime is not None: |
|
|
|
self.lastSentTime = time_ |
|
|
|
|
|
|
|
def data_received(self, data): |
|
|
|
time_ = time.time() |
|
|
@@ -410,6 +433,7 @@ class IRCClientProtocol(asyncio.Protocol): |
|
|
|
self.logger.error('IRC connection registered but not authenticated, terminating connection') |
|
|
|
self.transport.close() |
|
|
|
return |
|
|
|
self.lastSentTime = time.time() |
|
|
|
self._send_join_part(b'JOIN', self.channels) |
|
|
|
|
|
|
|
# Bot getting KICKed |
|
|
@@ -579,6 +603,8 @@ class IRCClient: |
|
|
|
try: |
|
|
|
self.logger.debug('Creating IRC connection') |
|
|
|
self._transport, self._protocol = await loop.create_connection(lambda: IRCClientProtocol(self.messageQueue, connectionClosedEvent, loop, self.config, self.channels), self.config['irc']['host'], self.config['irc']['port'], ssl = self._get_ssl_context()) |
|
|
|
self.logger.debug('Starting send queue processing') |
|
|
|
asyncio.create_task(self._protocol.send_queue()) # Quits automatically on connectionClosedEvent |
|
|
|
self.logger.debug('Waiting for connection closure or SIGINT') |
|
|
|
try: |
|
|
|
await asyncio.wait((connectionClosedEvent.wait(), sigintEvent.wait()), return_when = asyncio.FIRST_COMPLETED) |
|
|
|