Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
feat: fix some oversights in the AutoShardClient (#661)
Browse files Browse the repository at this point in the history
* feat: add shard change presence method

* feat: add get shard id method

* docs: expose auto shard client in docs
  • Loading branch information
LordOfPolls authored Oct 14, 2022
1 parent 8ad466c commit aa45a06
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/src/API Reference/.pages
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
nav:
- Client.md
- AutoShardClient.md
- const.md
- errors.md
- API_Communication
Expand Down
41 changes: 39 additions & 2 deletions naff/client/auto_shard_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import naff.api.events as events
from naff.api.gateway.state import ConnectionState
Expand All @@ -12,7 +12,8 @@
to_snowflake,
)
from naff.models.naff.listener import Listener
from ..api.events import ShardConnect
from naff.models.discord import Status, Activity
from naff.api.events import ShardConnect

if TYPE_CHECKING:
from naff.models import Snowflake_Type
Expand Down Expand Up @@ -104,6 +105,18 @@ def get_shards_guild(self, shard_id: int) -> list[Guild]:
"""
return [guild for key, guild in self.cache.guild_cache.items() if ((key >> 22) % self.total_shards) == shard_id]

def get_shard_id(self, guild_id: "Snowflake_Type") -> int:
"""
Get the shard ID for a given guild.
Args:
guild_id: The ID of the guild
Returns:
The shard ID for the guild
"""
return (int(guild_id) >> 22) % self.total_shards

@Listener.create()
async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
"""
Expand Down Expand Up @@ -228,3 +241,27 @@ async def login(self, token) -> None:
self._connection_states: list[ConnectionState] = [
ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards)
]

async def change_presence(
self,
status: Optional[str | Status] = Status.ONLINE,
activity: Optional[str | Activity] = None,
*,
shard_id: int | None = None,
) -> None:
"""
Change the bot's presence.
Args:
status: The status for the bot to be. i.e. online, afk, etc.
activity: The activity for the bot to be displayed as doing.
shard_id: The shard to change the presence on. If not specified, the presence will be changed on all shards.
!!! note
Bots may only be `playing` `streaming` `listening` `watching` or `competing`, other activity types are likely to fail.
"""
if shard_id is None:
await asyncio.gather(*[shard.change_presence(status, activity) for shard in self._connection_states])
else:
await self._connection_states[shard_id].change_presence(status, activity)

0 comments on commit aa45a06

Please sign in to comment.