Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
fix: unset ready on reconnect (#595)
Browse files Browse the repository at this point in the history
* feat: unset ready on reconnect

* feat: support new ws resume url (#594)

* fix: handle resume-ready guild create events

do not view the horrors that lay within, simply know they exist, and know they protect you. A necessary evil to defend the weak and protect the holy. To smite the miscreants whom aim to damage your cache
  • Loading branch information
LordOfPolls authored Aug 14, 2022
1 parent 6faa36b commit 38dfe69
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 42 deletions.
14 changes: 11 additions & 3 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger
from naff.client.const import logger, MISSING
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
self.session_id = None

self.ws_url = state.gateway_url
self.ws_resume_url = MISSING

# This lock needs to be held to send something over the gateway, but is also held when
# reconnecting. That way there's no race conditions between sending and reconnecting.
Expand Down Expand Up @@ -192,11 +193,11 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:

case OPCODE.RECONNECT:
logger.info("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True)
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
logger.warning("Gateway has invalidated session! Reconnecting...")
return await self.reconnect(resume=data)
return await self.reconnect(resume=data, url=self.ws_resume_url if data else None)

case _:
return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
Expand All @@ -208,8 +209,10 @@ async def dispatch_event(self, data, seq, event) -> None:
self._trace = data.get("_trace", [])
self.sequence = seq
self.session_id = data["session_id"]
self.ws_resume_url = data["resume_gateway_url"]
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
# todo: future polls, improve guild caching here. run the debugger. you'll see why
return self.state.client.dispatch(events.WebsocketReady(data))

case "RESUMED":
Expand Down Expand Up @@ -263,6 +266,11 @@ async def _identify(self) -> None:
f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!"
)

async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
self.state.clear_ready()
self._ready.clear()
await super().reconnect(resume=resume, code=code, url=url)

async def _resume_connection(self) -> None:
"""Send a resume payload to the gateway."""
if self.ws is None:
Expand Down
5 changes: 5 additions & 0 deletions naff/api/gateway/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ async def stop(self) -> None:

self.gateway_started.clear()

def clear_ready(self) -> None:
"""Clear the ready event."""
self._shard_ready.clear()
self.client._ready.clear() # noinspection PyProtectedMember

async def _ws_connect(self) -> None:
"""Connect to the Discord Gateway."""
logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...")
Expand Down
4 changes: 2 additions & 2 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def receive(self, force: bool = False) -> str:

return msg

async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
async with self._race_lock:
self._closed.clear()

Expand All @@ -246,7 +246,7 @@ async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
self.ws = None
self._zlib = zlib.decompressobj()

self.ws = await self.state.client.http.websocket_connect(self.ws_url)
self.ws = await self.state.client.http.websocket_connect(url or self.ws_url)

hello = await self.receive(force=True)
self.heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
Expand Down
95 changes: 58 additions & 37 deletions naff/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,47 +725,68 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
expected_guilds = {to_snowflake(guild["id"]) for guild in data["guilds"]}
self._user._add_guilds(expected_guilds)

while True:
try: # wait to let guilds cache
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

except asyncio.TimeoutError:
logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache")
break
self._guild_event.clear()

if len(self.cache.guild_cache) == len(expected_guilds):
# all guilds cached
break

if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

# run any pending startup tasks
if self.async_startup_tasks:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
await self._init_interactions()
while True:
try: # wait to let guilds cache
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

except asyncio.TimeoutError:
logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache")
break
self._guild_event.clear()

if len(self.cache.guild_cache) == len(expected_guilds):
# all guilds cached
break

if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

# run any pending startup tasks
if self.async_startup_tasks:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
await self._init_interactions()

self._ready.set()
if not self._startup:
self._startup = True
self.dispatch(events.Startup())

else:
# reconnect ready
ready_guilds = set()

async def _temp_listener(_event: events.RawGatewayEvent) -> None:
ready_guilds.add(_event.data["id"])

listener = Listener.create("_on_raw_guild_create")(_temp_listener)
self.add_listener(listener)

while True:
try:
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if len(ready_guilds) == len(expected_guilds):
break
except asyncio.TimeoutError:
break

self.listeners["raw_guild_create"].remove(listener)

self._ready.set()
self.dispatch(events.Ready())

async def login(self, token) -> None:
Expand Down

0 comments on commit 38dfe69

Please sign in to comment.