Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
feat: support new ws resume url (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
LordOfPolls authored Aug 10, 2022
1 parent efa7cde commit 59fc704
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger
from naff.client.const import logger, MISSING
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
self.session_id = None

self.ws_url = state.gateway_url
self.ws_resume_url = MISSING

# This lock needs to be held to send something over the gateway, but is also held when
# reconnecting. That way there's no race conditions between sending and reconnecting.
Expand Down Expand Up @@ -192,11 +193,11 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:

case OPCODE.RECONNECT:
logger.info("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True)
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
logger.warning("Gateway has invalidated session! Reconnecting...")
return await self.reconnect(resume=data)
return await self.reconnect(resume=data, url=self.ws_resume_url if data else None)

case _:
return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}")
Expand All @@ -208,8 +209,10 @@ async def dispatch_event(self, data, seq, event) -> None:
self._trace = data.get("_trace", [])
self.sequence = seq
self.session_id = data["session_id"]
self.ws_resume_url = data["resume_gateway_url"]
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
# todo: future polls, improve guild caching here. run the debugger. you'll see why
return self.state.client.dispatch(events.WebsocketReady(data))

case "RESUMED":
Expand Down
4 changes: 2 additions & 2 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def receive(self, force: bool = False) -> str:

return msg

async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
async def reconnect(self, *, resume: bool = False, code: int = 1012, url: str | None = None) -> None:
async with self._race_lock:
self._closed.clear()

Expand All @@ -246,7 +246,7 @@ async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None:
self.ws = None
self._zlib = zlib.decompressobj()

self.ws = await self.state.client.http.websocket_connect(self.ws_url)
self.ws = await self.state.client.http.websocket_connect(url or self.ws_url)

hello = await self.receive(force=True)
self.heartbeat_interval = hello["d"]["heartbeat_interval"] / 1000
Expand Down

0 comments on commit 59fc704

Please sign in to comment.