Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

fix: unset ready on reconnect #595

Merged
merged 3 commits into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger
from naff.client.const import logger, MISSING
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
self.session_id = None

self.ws_url = state.gateway_url
self.ws_resume_url = MISSING

# This lock needs to be held to send something over the gateway, but is also held when
# reconnecting. That way there's no race conditions between sending and reconnecting.
Expand Down Expand Up @@ -192,11 +193,11 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:

case OPCODE.RECONNECT:
logger.info("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True)
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
logger.warning("Gateway has invalidated session! Reconnecting...")
return await self.reconnect(resume=data)
return await self.reconnect(resume=data, url=self.ws_resume_url if data else None)

case _:
return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
Expand All @@ -208,8 +209,10 @@ async def dispatch_event(self, data, seq, event) -> None:
self._trace = data.get("_trace", [])
self.sequence = seq
self.session_id = data["session_id"]
self.ws_resume_url = data["resume_gateway_url"]
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
# todo: future polls, improve guild caching here. run the debugger. you'll see why
return self.state.client.dispatch(events.WebsocketReady(data))

case "RESUMED":
Expand Down Expand Up @@ -263,6 +266,11 @@ async def _identify(self) -> None:
f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!"
)

async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
self.state.clear_ready()
self._ready.clear()
await super().reconnect(resume=resume, code=code, url=url)

async def _resume_connection(self) -> None:
"""Send a resume payload to the gateway."""
if self.ws is None:
Expand Down
5 changes: 5 additions & 0 deletions naff/api/gateway/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ async def stop(self) -> None:

self.gateway_started.clear()

def clear_ready(self) -> None:
"""Clear the ready event."""
self._shard_ready.clear()
self.client._ready.clear() # noinspection PyProtectedMember

async def _ws_connect(self) -> None:
"""Connect to the Discord Gateway."""
logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...")
Expand Down
4 changes: 2 additions & 2 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def receive(self, force: bool = False) -> str:

return msg

async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
async with self._race_lock:
self._closed.clear()

Expand All @@ -246,7 +246,7 @@ async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
self.ws = None
self._zlib = zlib.decompressobj()

self.ws = await self.state.client.http.websocket_connect(self.ws_url)
self.ws = await self.state.client.http.websocket_connect(url or self.ws_url)

hello = await self.receive(force=True)
self.heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
Expand Down
95 changes: 58 additions & 37 deletions naff/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,47 +725,68 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
expected_guilds = {to_snowflake(guild["id"]) for guild in data["guilds"]}
self._user._add_guilds(expected_guilds)

while True:
try: # wait to let guilds cache
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

except asyncio.TimeoutError:
logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache")
break
self._guild_event.clear()

if len(self.cache.guild_cache) == len(expected_guilds):
# all guilds cached
break

if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

# run any pending startup tasks
if self.async_startup_tasks:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
await self._init_interactions()
while True:
try: # wait to let guilds cache
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

except asyncio.TimeoutError:
logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache")
break
self._guild_event.clear()

if len(self.cache.guild_cache) == len(expected_guilds):
# all guilds cached
break

if self.fetch_members:
# ensure all guilds have completed chunking
for guild in self.guilds:
if guild and not guild.chunked.is_set():
logger.debug(f"Waiting for {guild.id} to chunk")
await guild.chunked.wait()

# run any pending startup tasks
if self.async_startup_tasks:
try:
await asyncio.gather(*self.async_startup_tasks)
except Exception as e:
self.dispatch(events.Error("async-extension-loader", e))

# cache slash commands
if not self._startup:
await self._init_interactions()

self._ready.set()
if not self._startup:
self._startup = True
self.dispatch(events.Startup())

else:
# reconnect ready
ready_guilds = set()

async def _temp_listener(_event: events.RawGatewayEvent) -> None:
ready_guilds.add(_event.data["id"])

listener = Listener.create("_on_raw_guild_create")(_temp_listener)
self.add_listener(listener)

while True:
try:
await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout)
if len(ready_guilds) == len(expected_guilds):
break
except asyncio.TimeoutError:
break

self.listeners["raw_guild_create"].remove(listener)

self._ready.set()
self.dispatch(events.Ready())

async def login(self, token) -> None:
Expand Down