From efa7cdeefce356dbe1f88c4af83c29dbadaaf792 Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Wed, 10 Aug 2022 10:43:12 +0100 Subject: [PATCH 1/3] feat: unset ready on reconnect --- naff/api/gateway/gateway.py | 5 +++++ naff/api/gateway/state.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/naff/api/gateway/gateway.py b/naff/api/gateway/gateway.py index a04452916..83382d41d 100644 --- a/naff/api/gateway/gateway.py +++ b/naff/api/gateway/gateway.py @@ -263,6 +263,11 @@ async def _identify(self) -> None: f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!" ) + async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None: + self.state.clear_ready() + self._ready.clear() + await super().reconnect(resume=resume, code=code, url=url) + async def _resume_connection(self) -> None: """Send a resume payload to the gateway.""" if self.ws is None: diff --git a/naff/api/gateway/state.py b/naff/api/gateway/state.py index f888d288f..e40d699b5 100644 --- a/naff/api/gateway/state.py +++ b/naff/api/gateway/state.py @@ -91,6 +91,11 @@ async def stop(self) -> None: self.gateway_started.clear() + def clear_ready(self) -> None: + """Clear the ready event.""" + self._shard_ready.clear() + self.client._ready.clear() # noinspection PyProtectedMember + async def _ws_connect(self) -> None: """Connect to the Discord Gateway.""" logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...") From 59fc704155c60f469d0990b697f3430b589ceaeb Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Wed, 10 Aug 2022 19:34:01 +0100 Subject: [PATCH 2/3] feat: support new ws resume url (#594) --- naff/api/gateway/gateway.py | 9 ++++++--- naff/api/gateway/websocket.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/naff/api/gateway/gateway.py b/naff/api/gateway/gateway.py index 83382d41d..52effc601 100644 --- a/naff/api/gateway/gateway.py +++ b/naff/api/gateway/gateway.py @@ -7,7 +7,7 @@ from typing import TypeVar, TYPE_CHECKING from naff.api import events -from naff.client.const import logger +from naff.client.const import logger, MISSING from naff.client.utils.input_utils import OverriddenJson from naff.client.utils.serializer import dict_filter_none from naff.models.discord.enums import Status @@ -66,6 +66,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None: self.session_id = None self.ws_url = state.gateway_url + self.ws_resume_url = MISSING # This lock needs to be held to send something over the gateway, but is also held when # reconnecting. That way there's no race conditions between sending and reconnecting. @@ -192,11 +193,11 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None: case OPCODE.RECONNECT: logger.info("Gateway requested reconnect. Reconnecting...") - return await self.reconnect(resume=True) + return await self.reconnect(resume=True, url=self.ws_resume_url) case OPCODE.INVALIDATE_SESSION: logger.warning("Gateway has invalidated session! Reconnecting...") - return await self.reconnect(resume=data) + return await self.reconnect(resume=data, url=self.ws_resume_url if data else None) case _: return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}") @@ -208,8 +209,10 @@ async def dispatch_event(self, data, seq, event) -> None: self._trace = data.get("_trace", []) self.sequence = seq self.session_id = data["session_id"] + self.ws_resume_url = data["resume_gateway_url"] logger.info(f"Shard {self.shard[0]} has connected to gateway!") logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}") + # todo: future polls, improve guild caching here. run the debugger. you'll see why return self.state.client.dispatch(events.WebsocketReady(data)) case "RESUMED": diff --git a/naff/api/gateway/websocket.py b/naff/api/gateway/websocket.py index cde3341c6..76bee2b2e 100644 --- a/naff/api/gateway/websocket.py +++ b/naff/api/gateway/websocket.py @@ -236,7 +236,7 @@ async def receive(self, force: bool = False) -> str: return msg - async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None: + async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None: async with self._race_lock: self._closed.clear() @@ -246,7 +246,7 @@ async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None: self.ws = None self._zlib = zlib.decompressobj() - self.ws = await self.state.client.http.websocket_connect(self.ws_url) + self.ws = await self.state.client.http.websocket_connect(url or self.ws_url) hello = await self.receive(force=True) self.heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000 From 9b86f7fc7bbd030b29504ebb560d8b37596b6e01 Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Wed, 10 Aug 2022 20:08:57 +0100 Subject: [PATCH 3/3] fix: handle resume-ready guild create events do not view the horrors that lay within, simply know they exist, and know they protect you. A necessary evil to defend the weak and protect the holy. To smite the miscreants whom aim to damage your cache --- naff/client/client.py | 95 ++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/naff/client/client.py b/naff/client/client.py index 8a3e17de4..e427b8177 100644 --- a/naff/client/client.py +++ b/naff/client/client.py @@ -725,47 +725,68 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: expected_guilds = {to_snowflake(guild["id"]) for guild in data["guilds"]} self._user._add_guilds(expected_guilds) - while True: - try: # wait to let guilds cache - await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) - if self.fetch_members: - # ensure all guilds have completed chunking - for guild in self.guilds: - if guild and not guild.chunked.is_set(): - logger.debug(f"Waiting for {guild.id} to chunk") - await guild.chunked.wait() - - except asyncio.TimeoutError: - logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") - break - self._guild_event.clear() - - if len(self.cache.guild_cache) == len(expected_guilds): - # all guilds cached - break - - if self.fetch_members: - # ensure all guilds have completed chunking - for guild in self.guilds: - if guild and not guild.chunked.is_set(): - logger.debug(f"Waiting for {guild.id} to chunk") - await guild.chunked.wait() - - # run any pending startup tasks - if self.async_startup_tasks: - try: - await asyncio.gather(*self.async_startup_tasks) - except Exception as e: - self.dispatch(events.Error("async-extension-loader", e)) - - # cache slash commands if not self._startup: - await self._init_interactions() + while True: + try: # wait to let guilds cache + await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) + if self.fetch_members: + # ensure all guilds have completed chunking + for guild in self.guilds: + if guild and not guild.chunked.is_set(): + logger.debug(f"Waiting for {guild.id} to chunk") + await guild.chunked.wait() + + except asyncio.TimeoutError: + logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") + break + self._guild_event.clear() + + if len(self.cache.guild_cache) == len(expected_guilds): + # all guilds cached + break + + if self.fetch_members: + # ensure all guilds have completed chunking + for guild in self.guilds: + if guild and not guild.chunked.is_set(): + logger.debug(f"Waiting for {guild.id} to chunk") + await guild.chunked.wait() + + # run any pending startup tasks + if self.async_startup_tasks: + try: + await asyncio.gather(*self.async_startup_tasks) + except Exception as e: + self.dispatch(events.Error("async-extension-loader", e)) + + # cache slash commands + if not self._startup: + await self._init_interactions() - self._ready.set() - if not self._startup: self._startup = True self.dispatch(events.Startup()) + + else: + # reconnect ready + ready_guilds = set() + + async def _temp_listener(_event: events.RawGatewayEvent) -> None: + ready_guilds.add(_event.data["id"]) + + listener = Listener.create("_on_raw_guild_create")(_temp_listener) + self.add_listener(listener) + + while True: + try: + await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) + if len(ready_guilds) == len(expected_guilds): + break + except asyncio.TimeoutError: + break + + self.listeners["raw_guild_create"].remove(listener) + + self._ready.set() self.dispatch(events.Ready()) async def login(self, token) -> None: