From aa45a066eff9cf0717ad4eb676cf3bec7666aa69 Mon Sep 17 00:00:00 2001 From: LordOfPolls Date: Fri, 14 Oct 2022 06:18:37 +0100 Subject: [PATCH] feat: fix some oversights in the AutoShardClient (#661) * feat: add shard change presence method * feat: add get shard id method * docs: expose auto shard client in docs --- docs/src/API Reference/.pages | 1 + naff/client/auto_shard_client.py | 41 ++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/docs/src/API Reference/.pages b/docs/src/API Reference/.pages index ceae4fe09..b9562903f 100644 --- a/docs/src/API Reference/.pages +++ b/docs/src/API Reference/.pages @@ -1,5 +1,6 @@ nav: - Client.md + - AutoShardClient.md - const.md - errors.md - API_Communication diff --git a/naff/client/auto_shard_client.py b/naff/client/auto_shard_client.py index b43f9606f..2f282a87d 100644 --- a/naff/client/auto_shard_client.py +++ b/naff/client/auto_shard_client.py @@ -1,7 +1,7 @@ import asyncio import time from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import naff.api.events as events from naff.api.gateway.state import ConnectionState @@ -12,7 +12,8 @@ to_snowflake, ) from naff.models.naff.listener import Listener -from ..api.events import ShardConnect +from naff.models.discord import Status, Activity +from naff.api.events import ShardConnect if TYPE_CHECKING: from naff.models import Snowflake_Type @@ -104,6 +105,18 @@ def get_shards_guild(self, shard_id: int) -> list[Guild]: """ return [guild for key, guild in self.cache.guild_cache.items() if ((key >> 22) % self.total_shards) == shard_id] + def get_shard_id(self, guild_id: "Snowflake_Type") -> int: + """ + Get the shard ID for a given guild. + + Args: + guild_id: The ID of the guild + + Returns: + The shard ID for the guild + """ + return (int(guild_id) >> 22) % self.total_shards + @Listener.create() async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: """ @@ -228,3 +241,27 @@ async def login(self, token) -> None: self._connection_states: list[ConnectionState] = [ ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards) ] + + async def change_presence( + self, + status: Optional[str | Status] = Status.ONLINE, + activity: Optional[str | Activity] = None, + *, + shard_id: int | None = None, + ) -> None: + """ + Change the bot's presence. + + Args: + status: The status for the bot to be. i.e. online, afk, etc. + activity: The activity for the bot to be displayed as doing. + shard_id: The shard to change the presence on. If not specified, the presence will be changed on all shards. + + !!! note + Bots may only be `playing` `streaming` `listening` `watching` or `competing`, other activity types are likely to fail. + + """ + if shard_id is None: + await asyncio.gather(*[shard.change_presence(status, activity) for shard in self._connection_states]) + else: + await self._connection_states[shard_id].change_presence(status, activity)