Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

fix: restore ability to resume ++ #619

Merged
merged 1 commit into from
Sep 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions naff/api/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TypeVar, TYPE_CHECKING

from naff.api import events
from naff.client.const import logger, MISSING
from naff.client.const import logger, MISSING, __api_version__
from naff.client.utils.input_utils import OverriddenJson
from naff.client.utils.serializer import dict_filter_none
from naff.models.discord.enums import Status
Expand Down Expand Up @@ -48,7 +48,6 @@ class GatewayClient(WebsocketClient):
Multiple `WebsocketClient` instances can be used to implement same-process sharding.

Attributes:
buffer: A buffer to hold incoming data until its complete
sequence: The sequence of this connection
session_id: The session ID of this connection

Expand Down Expand Up @@ -83,7 +82,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None:
self._ready = asyncio.Event()
self._close_gateway = asyncio.Event()

# Santity check, it is extremely important that an instance isn't reused.
# Sanity check, it is extremely important that an instance isn't reused.
self._entered = False

async def __aenter__(self: SELF) -> SELF:
Expand Down Expand Up @@ -177,6 +176,7 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
match op:

case OPCODE.HEARTBEAT:
logger.debug("Received heartbeat request from gateway")
return await self.send_heartbeat()

case OPCODE.HEARTBEAT_ACK:
Expand All @@ -192,7 +192,7 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None:
return self._acknowledged.set()

case OPCODE.RECONNECT:
logger.info("Gateway requested reconnect. Reconnecting...")
logger.debug("Gateway requested reconnect. Reconnecting...")
return await self.reconnect(resume=True, url=self.ws_resume_url)

case OPCODE.INVALIDATE_SESSION:
Expand All @@ -209,7 +209,9 @@ async def dispatch_event(self, data, seq, event) -> None:
self._trace = data.get("_trace", [])
self.sequence = seq
self.session_id = data["session_id"]
self.ws_resume_url = data["resume_gateway_url"]
self.ws_resume_url = (
f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream"
)
logger.info(f"Shard {self.shard[0]} has connected to gateway!")
logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}")
# todo: future polls, improve guild caching here. run the debugger. you'll see why
Expand Down Expand Up @@ -287,7 +289,7 @@ async def _resume_connection(self) -> None:
logger.debug(f"{self.shard[0]} is attempting to resume a connection")

async def send_heartbeat(self) -> None:
await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, True)
await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True)
logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat")

async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None:
Expand Down
13 changes: 7 additions & 6 deletions naff/api/gateway/websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import collections
import random
import time
import zlib
from abc import abstractmethod
Expand Down Expand Up @@ -275,12 +276,12 @@ async def _start_bee_gees(self) -> None:
if self.heartbeat_interval is None:
raise RuntimeError

# try:
# await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
# except asyncio.TimeoutError:
# pass
# else:
# return
try:
await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5))
except asyncio.TimeoutError:
pass
else:
return

logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds")
while not self._kill_bee_gees.is_set():
Expand Down