From e325f87cce63f0c9a6222e3d9bd87b2bba6d5af9 Mon Sep 17 00:00:00 2001 From: Eneg <42005170+Enegg@users.noreply.github.com> Date: Fri, 15 Nov 2024 06:49:37 +0100 Subject: [PATCH] Typed IDs --- disnake/abc.py | 82 ++++----- disnake/app_commands.py | 15 +- disnake/asset.py | 13 +- disnake/automod.py | 33 ++-- disnake/channel.py | 254 ++++++++++++++-------------- disnake/client.py | 279 ++++++++++++++++++++----------- disnake/colour.py | 4 +- disnake/embeds.py | 4 +- disnake/emoji.py | 25 +-- disnake/guild.py | 247 ++++++++++++++++----------- disnake/guild_preview.py | 3 +- disnake/guild_scheduled_event.py | 30 ++-- disnake/invite.py | 29 ++-- disnake/member.py | 39 ++--- disnake/mentions.py | 5 +- disnake/message.py | 85 +++++----- disnake/mixins.py | 63 ++++++- disnake/object.py | 14 +- disnake/partial_emoji.py | 19 ++- disnake/reaction.py | 3 +- disnake/role.py | 3 +- disnake/state.py | 195 ++++++++++++--------- disnake/sticker.py | 5 +- disnake/threads.py | 47 ++++-- disnake/types/ids.py | 92 ++++++++++ disnake/user.py | 11 +- disnake/utils.py | 23 +-- disnake/voice_client.py | 3 +- disnake/webhook/async_.py | 57 ++++--- disnake/webhook/sync.py | 29 ++-- disnake/widget.py | 5 +- 31 files changed, 1059 insertions(+), 657 deletions(-) create mode 100644 disnake/types/ids.py diff --git a/disnake/abc.py b/disnake/abc.py index c6c4c651cf..d5d8fba26d 100644 --- a/disnake/abc.py +++ b/disnake/abc.py @@ -43,6 +43,15 @@ from .permissions import PermissionOverwrite, Permissions from .role import Role from .sticker import GuildSticker, StandardSticker, StickerItem +from .types.ids import ( + CategoryId, + ChannelId, + GuildId, + MessageId, + PrivateChannelId, + UserId, + overload_fetch, +) from .ui.action_row import components_to_dict from .utils import _overload_with_permissions from .voice_client import VoiceClient, VoiceProtocol @@ -74,6 +83,7 @@ from .iterators import HistoryIterator from .member import Member from .message import Message, MessageReference, PartialMessage + from .mixins import IdT from .poll import Poll from .state import ConnectionState from .threads import AnyThreadArchiveDuration, ForumTag @@ -100,7 +110,7 @@ @runtime_checkable -class Snowflake(Protocol): +class Snowflake(Protocol["IdT"]): """An ABC that details the common operations on a Discord model. Almost all :ref:`Discord models ` meet this @@ -116,11 +126,13 @@ class Snowflake(Protocol): """ __slots__ = () - id: int + + @property + def id(self) -> IdT: ... @runtime_checkable -class User(Snowflake, Protocol): +class User(Snowflake[UserId], Protocol): """An ABC that details the common operations on a Discord user. The following classes implement this ABC: @@ -180,7 +192,7 @@ def avatar(self) -> Optional[Asset]: @runtime_checkable -class PrivateChannel(Snowflake, Protocol): +class PrivateChannel(Snowflake[PrivateChannelId], Protocol): """An ABC that details the common operations on a private Discord channel. The following classes implement this ABC: @@ -255,12 +267,12 @@ class GuildChannel(ABC): __slots__ = () - id: int + id: ChannelId name: str guild: Guild type: ChannelType position: int - category_id: Optional[int] + category_id: Optional[CategoryId] _flags: int _state: ConnectionState _overwrites: List[_Overwrites] @@ -269,8 +281,7 @@ class GuildChannel(ABC): def __init__( self, *, state: ConnectionState, guild: Guild, data: Mapping[str, Any] - ) -> None: - ... + ) -> None: ... def __str__(self) -> str: return self.name @@ -285,7 +296,7 @@ def _update(self, guild: Guild, data: Dict[str, Any]) -> None: async def _move( self, position: int, - parent_id: Optional[int] = None, + parent_id: Optional[CategoryId] = None, lock_permissions: bool = False, *, reason: Optional[str], @@ -330,7 +341,7 @@ async def _edit( position: int = MISSING, nsfw: bool = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -347,7 +358,7 @@ async def _edit( default_layout: ThreadLayout = MISSING, reason: Optional[str] = None, ) -> Optional[ChannelPayload]: - parent_id: Optional[int] + parent_id: Optional[CategoryId] if category is not MISSING: # if category is given, it's either `None` (no parent) or a category channel parent_id = category.id if category else None @@ -793,7 +804,7 @@ def permissions_for( # Apply channel specific role permission overwrites for overwrite in remaining_overwrites: - if overwrite.is_role() and roles.has(overwrite.id): + if overwrite.is_role() and roles.has(overwrite.id): # type: ignore denies |= overwrite.deny allows |= overwrite.allow @@ -843,8 +854,7 @@ async def set_permissions( *, overwrite: Optional[PermissionOverwrite] = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload @_overload_with_permissions @@ -911,8 +921,7 @@ async def set_permissions( view_channel: Optional[bool] = ..., view_creator_monetization_analytics: Optional[bool] = ..., view_guild_insights: Optional[bool] = ..., - ) -> None: - ... + ) -> None: ... async def set_permissions( self, @@ -1030,7 +1039,7 @@ async def _clone_impl( base_attrs: Dict[str, Any], *, name: Optional[str] = None, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING, reason: Optional[str] = None, ) -> Self: @@ -1115,11 +1124,10 @@ async def move( *, beginning: bool, offset: int = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -1127,11 +1135,10 @@ async def move( *, end: bool, offset: int = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -1139,11 +1146,10 @@ async def move( *, before: Snowflake, offset: int = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -1151,11 +1157,10 @@ async def move( *, after: Snowflake, offset: int = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... async def move(self, **kwargs: Any) -> None: """|coro| @@ -1442,8 +1447,7 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def send( @@ -1464,8 +1468,7 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def send( @@ -1486,8 +1489,7 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def send( @@ -1508,8 +1510,7 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - ) -> Message: - ... + ) -> Message: ... async def send( self, @@ -1817,7 +1818,8 @@ def typing(self) -> Typing: """ return Typing(self) - async def fetch_message(self, id: int, /) -> Message: + @overload_fetch + async def fetch_message(self, id: MessageId, /) -> Message: """|coro| Retrieves a single :class:`.Message` from the destination. @@ -1961,9 +1963,9 @@ class Connectable(Protocol): __slots__ = () _state: ConnectionState guild: Guild - id: int + id: ChannelId - def _get_voice_client_key(self) -> Tuple[int, str]: + def _get_voice_client_key(self) -> Tuple[GuildId, str]: raise NotImplementedError def _get_voice_state_pair(self) -> Tuple[int, int]: diff --git a/disnake/app_commands.py b/disnake/app_commands.py index 727f35cb93..c938eca1e8 100644 --- a/disnake/app_commands.py +++ b/disnake/app_commands.py @@ -19,6 +19,7 @@ ) from .i18n import Localized from .permissions import Permissions +from .types.ids import ApplicationCommandId, ApplicationId, GuildId from .utils import MISSING, _get_as_snowflake, _maybe_cast if TYPE_CHECKING: @@ -608,9 +609,9 @@ class _APIApplicationCommandMixin: __repr_info__ = ("id",) def _update_common(self, data: ApplicationCommandPayload) -> None: - self.id: int = int(data["id"]) - self.application_id: int = int(data["application_id"]) - self.guild_id: Optional[int] = _get_as_snowflake(data, "guild_id") + self.id: ApplicationCommandId = ApplicationCommandId(int(data["id"])) + self.application_id: ApplicationId = ApplicationId(int(data["application_id"])) + self.guild_id: Optional[GuildId] = _get_as_snowflake(data, "guild_id", GuildId) self.version: int = int(data["version"]) # deprecated, but kept until API stops returning this field self._default_permission = data.get("default_permission") is not False @@ -1016,13 +1017,13 @@ class ApplicationCommandPermissions: __slots__ = ("id", "type", "permission", "_guild_id") - def __init__(self, *, data: ApplicationCommandPermissionsPayload, guild_id: int) -> None: + def __init__(self, *, data: ApplicationCommandPermissionsPayload, guild_id: GuildId) -> None: self.id: int = int(data["id"]) self.type: ApplicationCommandPermissionType = try_enum( ApplicationCommandPermissionType, data["type"] ) self.permission: bool = data["permission"] - self._guild_id: int = guild_id + self._guild_id: GuildId = guild_id def __repr__(self) -> str: return f"" @@ -1079,8 +1080,8 @@ def __init__( ) -> None: self._state: ConnectionState = state self.id: int = int(data["id"]) - self.application_id: int = int(data["application_id"]) - self.guild_id: int = int(data["guild_id"]) + self.application_id: ApplicationId = ApplicationId(int(data["application_id"])) + self.guild_id: GuildId = GuildId(int(data["guild_id"])) self.permissions: List[ApplicationCommandPermissions] = [ ApplicationCommandPermissions(data=elem, guild_id=self.guild_id) diff --git a/disnake/asset.py b/disnake/asset.py index edb0d1c7a6..8e61ee69c5 100644 --- a/disnake/asset.py +++ b/disnake/asset.py @@ -11,6 +11,7 @@ from . import utils from .errors import DiscordException from .file import File +from .types.ids import GuildId, MemberId, RoleId, UserId, WebhookId __all__ = ("Asset",) @@ -214,7 +215,7 @@ def _from_default_avatar(cls, state: AnyState, index: int) -> Self: ) @classmethod - def _from_avatar(cls, state: AnyState, user_id: int, avatar: str) -> Self: + def _from_avatar(cls, state: AnyState, user_id: Union[UserId, WebhookId], avatar: str) -> Self: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -226,7 +227,7 @@ def _from_avatar(cls, state: AnyState, user_id: int, avatar: str) -> Self: @classmethod def _from_guild_avatar( - cls, state: AnyState, guild_id: int, member_id: int, avatar: str + cls, state: AnyState, guild_id: GuildId, member_id: MemberId, avatar: str ) -> Self: animated = avatar.startswith("a_") format = "gif" if animated else "png" @@ -256,7 +257,7 @@ def _from_cover_image(cls, state: AnyState, object_id: int, cover_image_hash: st ) @classmethod - def _from_guild_image(cls, state: AnyState, guild_id: int, image: str, path: str) -> Self: + def _from_guild_image(cls, state: AnyState, guild_id: GuildId, image: str, path: str) -> Self: return cls( state, url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024", @@ -265,7 +266,7 @@ def _from_guild_image(cls, state: AnyState, guild_id: int, image: str, path: str ) @classmethod - def _from_guild_icon(cls, state: AnyState, guild_id: int, icon_hash: str) -> Self: + def _from_guild_icon(cls, state: AnyState, guild_id: GuildId, icon_hash: str) -> Self: animated = icon_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -296,7 +297,7 @@ def _from_banner(cls, state: AnyState, id: int, banner_hash: str) -> Self: ) @classmethod - def _from_role_icon(cls, state: AnyState, role_id: int, icon_hash: str) -> Self: + def _from_role_icon(cls, state: AnyState, role_id: RoleId, icon_hash: str) -> Self: animated = icon_hash.startswith("a_") format = "gif" if animated else "png" return cls( @@ -337,7 +338,7 @@ def __repr__(self) -> str: shorten = self._url.replace(self.BASE, "") return f"" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, Asset) and self._url == other._url def __hash__(self) -> int: diff --git a/disnake/automod.py b/disnake/automod.py index f07fe6f95c..d4544f58da 100644 --- a/disnake/automod.py +++ b/disnake/automod.py @@ -25,6 +25,7 @@ try_enum_to_int, ) from .flags import AutoModKeywordPresets +from .types.ids import CategoryId, ChannelId, ChannelOrThreadId, MemberId, MessageId, RoleId from .utils import MISSING, _get_as_snowflake, snowflake_time if TYPE_CHECKING: @@ -172,7 +173,7 @@ class AutoModSendAlertAction(AutoModAction): _metadata: AutoModSendAlertActionMetadata - def __init__(self, channel: Snowflake) -> None: + def __init__(self, channel: Snowflake[ChannelId]) -> None: super().__init__(type=AutoModActionType.send_alert_message) self._metadata["channel_id"] = channel.id @@ -296,8 +297,7 @@ def __init__( keyword_filter: Optional[Sequence[str]], regex_patterns: Optional[Sequence[str]] = None, allow_list: Optional[Sequence[str]] = None, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -306,8 +306,7 @@ def __init__( keyword_filter: Optional[Sequence[str]] = None, regex_patterns: Optional[Sequence[str]], allow_list: Optional[Sequence[str]] = None, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -315,14 +314,12 @@ def __init__( *, presets: AutoModKeywordPresets, allow_list: Optional[Sequence[str]] = None, - ) -> None: - ... + ) -> None: ... @overload def __init__( self, *, mention_total_limit: int, mention_raid_protection_enabled: bool = False - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -472,7 +469,7 @@ def __init__(self, *, data: AutoModRulePayload, guild: Guild) -> None: self.id: int = int(data["id"]) self.name: str = data["name"] self.enabled: bool = data["enabled"] - self.creator_id: int = int(data["creator_id"]) + self.creator_id: MemberId = MemberId(int(data["creator_id"])) self.event_type: AutoModEventType = try_enum(AutoModEventType, data["event_type"]) self.trigger_type: AutoModTriggerType = try_enum(AutoModTriggerType, data["trigger_type"]) self._actions: List[AutoModAction] = [ @@ -540,8 +537,8 @@ async def edit( trigger_metadata: AutoModTriggerMetadata = MISSING, actions: Sequence[AutoModAction] = MISSING, enabled: bool = MISSING, - exempt_roles: Optional[Iterable[Snowflake]] = MISSING, - exempt_channels: Optional[Iterable[Snowflake]] = MISSING, + exempt_roles: Optional[Iterable[Snowflake[RoleId]]] = MISSING, + exempt_channels: Optional[Iterable[Snowflake[Union[ChannelId, CategoryId]]]] = MISSING, reason: Optional[str] = None, ) -> AutoModRule: """|coro| @@ -747,10 +744,14 @@ def __init__(self, *, data: AutoModerationActionExecutionEvent, guild: Guild) -> self.rule_trigger_type: AutoModTriggerType = try_enum( AutoModTriggerType, data["rule_trigger_type"] ) - self.user_id: int = int(data["user_id"]) - self.channel_id: Optional[int] = _get_as_snowflake(data, "channel_id") - self.message_id: Optional[int] = _get_as_snowflake(data, "message_id") - self.alert_message_id: Optional[int] = _get_as_snowflake(data, "alert_system_message_id") + self.user_id: MemberId = MemberId(int(data["user_id"])) + self.channel_id: Optional[ChannelOrThreadId] = _get_as_snowflake( + data, "channel_id", ChannelOrThreadId + ) + self.message_id: Optional[MessageId] = _get_as_snowflake(data, "message_id", MessageId) + self.alert_message_id: Optional[MessageId] = _get_as_snowflake( + data, "alert_system_message_id", MessageId + ) self.content: str = data.get("content") or "" self.matched_keyword: Optional[str] = data.get("matched_keyword") self.matched_content: Optional[str] = data.get("matched_content") diff --git a/disnake/channel.py b/disnake/channel.py index a1a37a057f..7101bdc428 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -48,6 +48,18 @@ from .permissions import PermissionOverwrite, Permissions from .stage_instance import StageInstance from .threads import ForumTag, Thread +from .types.ids import ( + CategoryId, + ChannelId, + EmojiId, + GuildId, + MemberId, + MessageId, + PrivateChannelId, + ThreadId, + UserId, + overload_get, +) from .utils import MISSING __all__ = ( @@ -150,7 +162,7 @@ async def _single_delete_strategy(messages: Iterable[Message]) -> None: await m.delete() -class TextChannel(disnake.abc.Messageable, disnake.abc.GuildChannel, Hashable): +class TextChannel(disnake.abc.Messageable, disnake.abc.GuildChannel, Hashable[ChannelId]): """Represents a Discord guild text channel. .. collapse:: operations @@ -247,7 +259,7 @@ class TextChannel(disnake.abc.Messageable, disnake.abc.GuildChannel, Hashable): def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id = ChannelId(int(data["id"])) self._type: Literal[0, 5] = data["type"] self._update(guild, data) @@ -269,7 +281,9 @@ def _update(self, guild: Guild, data: TextChannelPayload) -> None: self.guild: Guild = guild # apparently this can be nullable in the case of a bad api deploy self.name: str = data.get("name") or "" - self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.category_id: Optional[CategoryId] = utils._get_as_snowflake( + data, "parent_id", CategoryId + ) self.topic: Optional[str] = data.get("topic") self.position: int = data["position"] self._flags = data.get("flags", 0) @@ -281,7 +295,9 @@ def _update(self, guild: Guild, data: TextChannelPayload) -> None: "default_auto_archive_duration", 1440 ) self._type: Literal[0, 5] = data.get("type", self._type) - self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: Optional[MessageId] = utils._get_as_snowflake( + data, "last_message_id", MessageId + ) self.last_pin_timestamp: Optional[datetime.datetime] = utils.parse_time( data.get("last_pin_timestamp") ) @@ -374,11 +390,10 @@ async def edit( self, *, position: int, - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... # only passing `sync_permissions` may or may not return a channel, # depending on whether the channel is in a category @@ -388,8 +403,7 @@ async def edit( *, sync_permissions: bool, reason: Optional[str] = ..., - ) -> Optional[TextChannel]: - ... + ) -> Optional[TextChannel]: ... @overload async def edit( @@ -400,7 +414,7 @@ async def edit( position: int = ..., nsfw: bool = ..., sync_permissions: bool = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., slowmode_delay: Optional[int] = ..., default_thread_slowmode_delay: Optional[int] = ..., default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = ..., @@ -408,8 +422,7 @@ async def edit( overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., flags: ChannelFlags = ..., reason: Optional[str] = ..., - ) -> TextChannel: - ... + ) -> TextChannel: ... async def edit( self, @@ -419,7 +432,7 @@ async def edit( position: int = MISSING, nsfw: bool = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -536,7 +549,7 @@ async def clone( topic: Optional[str] = MISSING, position: int = MISSING, nsfw: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: int = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: AnyThreadArchiveDuration = MISSING, @@ -637,7 +650,7 @@ async def clone( overwrites=overwrites, ) - async def delete_messages(self, messages: Iterable[Snowflake]) -> None: + async def delete_messages(self, messages: Iterable[Snowflake[MessageId]]) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -676,7 +689,7 @@ async def delete_messages(self, messages: Iterable[Snowflake]) -> None: return # do nothing if len(messages) == 1: - message_id: int = messages[0].id + message_id = messages[0].id await self._state.http.delete_message(self.id, message_id) return @@ -926,7 +939,8 @@ async def follow(self, *, destination: TextChannel, reason: Optional[str] = None ) return Webhook._as_follower(data, channel=destination, user=self._state.user) - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the given message ID. This is useful if you want to work with a message and only have its ID without @@ -948,7 +962,8 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: return PartialMessage(channel=self, id=message_id) - def get_thread(self, thread_id: int, /) -> Optional[Thread]: + @overload_get + def get_thread(self, thread_id: ThreadId, /) -> Optional[Thread]: """Returns a thread with the given ID. .. versionadded:: 2.0 @@ -970,12 +985,11 @@ async def create_thread( self, *, name: str, - message: Snowflake, + message: Snowflake[MessageId], auto_archive_duration: Optional[AnyThreadArchiveDuration] = None, slowmode_delay: Optional[int] = None, reason: Optional[str] = None, - ) -> Thread: - ... + ) -> Thread: ... @overload async def create_thread( @@ -987,14 +1001,13 @@ async def create_thread( invitable: Optional[bool] = None, slowmode_delay: Optional[int] = None, reason: Optional[str] = None, - ) -> Thread: - ... + ) -> Thread: ... async def create_thread( self, *, name: str, - message: Optional[Snowflake] = None, + message: Optional[Snowflake[MessageId]] = None, auto_archive_duration: Optional[AnyThreadArchiveDuration] = None, type: Optional[ThreadType] = None, invitable: Optional[bool] = None, @@ -1147,7 +1160,7 @@ def archived_threads( ) -class VocalGuildChannel(disnake.abc.Connectable, disnake.abc.GuildChannel, Hashable): +class VocalGuildChannel(disnake.abc.Connectable, disnake.abc.GuildChannel, Hashable[ChannelId]): __slots__ = ( "name", "id", @@ -1171,13 +1184,13 @@ def __init__( data: Union[VoiceChannelPayload, StageChannelPayload], ) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id = ChannelId(int(data["id"])) self._update(guild, data) - def _get_voice_client_key(self) -> Tuple[int, str]: + def _get_voice_client_key(self) -> Tuple[GuildId, str]: return self.guild.id, "guild_id" - def _get_voice_state_pair(self) -> Tuple[int, int]: + def _get_voice_state_pair(self) -> Tuple[GuildId, int]: return self.guild.id, self.id def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: @@ -1190,7 +1203,9 @@ def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPay VideoQualityMode, data.get("video_quality_mode", 1) ) self._flags = data.get("flags", 0) - self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.category_id: Optional[CategoryId] = utils._get_as_snowflake( + data, "parent_id", CategoryId + ) self.position: int = data["position"] # these don't exist in partial channel objects of slash command options self.bitrate: int = data.get("bitrate", 0) @@ -1213,7 +1228,7 @@ def members(self) -> List[Member]: return ret @property - def voice_states(self) -> Dict[int, VoiceState]: + def voice_states(self) -> Dict[MemberId, VoiceState]: """Returns a mapping of member IDs who have voice states in this channel. .. versionadded:: 1.3 @@ -1364,7 +1379,9 @@ def _update(self, guild: Guild, data: VoiceChannelPayload) -> None: super()._update(guild, data) self.nsfw: bool = data.get("nsfw", False) self.slowmode_delay: int = data.get("rate_limit_per_user", 0) - self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: Optional[MessageId] = utils._get_as_snowflake( + data, "last_message_id", MessageId + ) async def _get_channel(self): return self @@ -1384,7 +1401,7 @@ async def clone( bitrate: int = MISSING, user_limit: int = MISSING, position: int = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, rtc_region: Optional[Union[str, VoiceRegion]] = MISSING, video_quality_mode: VideoQualityMode = MISSING, nsfw: bool = MISSING, @@ -1501,7 +1518,8 @@ def last_message(self) -> Optional[Message]: """ return self._state._get_message(self.last_message_id) if self.last_message_id else None - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the given message ID. This is useful if you want to work with a message and only have its ID without @@ -1529,11 +1547,10 @@ async def edit( self, *, position: int, - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... # only passing `sync_permissions` may or may not return a channel, # depending on whether the channel is in a category @@ -1543,8 +1560,7 @@ async def edit( *, sync_permissions: bool, reason: Optional[str] = ..., - ) -> Optional[VoiceChannel]: - ... + ) -> Optional[VoiceChannel]: ... @overload async def edit( @@ -1555,7 +1571,7 @@ async def edit( user_limit: int = ..., position: int = ..., sync_permissions: bool = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., rtc_region: Optional[Union[str, VoiceRegion]] = ..., video_quality_mode: VideoQualityMode = ..., @@ -1563,8 +1579,7 @@ async def edit( slowmode_delay: Optional[int] = ..., flags: ChannelFlags = ..., reason: Optional[str] = ..., - ) -> VoiceChannel: - ... + ) -> VoiceChannel: ... async def edit( self, @@ -1574,7 +1589,7 @@ async def edit( user_limit: int = MISSING, position: int = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING, rtc_region: Optional[Union[str, VoiceRegion]] = MISSING, video_quality_mode: VideoQualityMode = MISSING, @@ -1685,7 +1700,7 @@ async def edit( # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - async def delete_messages(self, messages: Iterable[Snowflake]) -> None: + async def delete_messages(self, messages: Iterable[Snowflake[MessageId]]) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -1726,7 +1741,7 @@ async def delete_messages(self, messages: Iterable[Snowflake]) -> None: return # do nothing if len(messages) == 1: - message_id: int = messages[0].id + message_id = messages[0].id await self._state.http.delete_message(self.id, message_id) return @@ -2023,7 +2038,9 @@ def _update(self, guild: Guild, data: StageChannelPayload) -> None: self.topic: Optional[str] = data.get("topic") self.nsfw: bool = data.get("nsfw", False) self.slowmode_delay: int = data.get("rate_limit_per_user", 0) - self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") + self.last_message_id: Optional[MessageId] = utils._get_as_snowflake( + data, "last_message_id", MessageId + ) async def _get_channel(self): return self @@ -2087,7 +2104,7 @@ async def clone( bitrate: int = MISSING, # user_limit: int = MISSING, position: int = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: int = MISSING, rtc_region: Optional[Union[str, VoiceRegion]] = MISSING, video_quality_mode: VideoQualityMode = MISSING, @@ -2208,7 +2225,8 @@ def last_message(self) -> Optional[Message]: """ return self._state._get_message(self.last_message_id) if self.last_message_id else None - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the given message ID. This is useful if you want to work with a message and only have its ID without @@ -2346,11 +2364,10 @@ async def edit( self, *, position: int, - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... # only passing `sync_permissions` may or may not return a channel, # depending on whether the channel is in a category @@ -2360,8 +2377,7 @@ async def edit( *, sync_permissions: bool, reason: Optional[str] = ..., - ) -> Optional[StageChannel]: - ... + ) -> Optional[StageChannel]: ... @overload async def edit( @@ -2372,7 +2388,7 @@ async def edit( user_limit: int = ..., position: int = ..., sync_permissions: bool = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., rtc_region: Optional[Union[str, VoiceRegion]] = ..., video_quality_mode: VideoQualityMode = ..., @@ -2380,8 +2396,7 @@ async def edit( slowmode_delay: Optional[int] = ..., flags: ChannelFlags = ..., reason: Optional[str] = ..., - ) -> StageChannel: - ... + ) -> StageChannel: ... async def edit( self, @@ -2391,7 +2406,7 @@ async def edit( user_limit: int = MISSING, position: int = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING, rtc_region: Optional[Union[str, VoiceRegion]] = MISSING, video_quality_mode: VideoQualityMode = MISSING, @@ -2510,7 +2525,7 @@ async def edit( # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - async def delete_messages(self, messages: Iterable[Snowflake]) -> None: + async def delete_messages(self, messages: Iterable[Snowflake[MessageId]]) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -2551,7 +2566,7 @@ async def delete_messages(self, messages: Iterable[Snowflake]) -> None: return # do nothing if len(messages) == 1: - message_id: int = messages[0].id + message_id = messages[0].id await self._state.http.delete_message(self.id, message_id) return @@ -2742,7 +2757,7 @@ async def create_webhook( return Webhook.from_state(data, state=self._state) -class CategoryChannel(disnake.abc.GuildChannel, Hashable): +class CategoryChannel(disnake.abc.GuildChannel, Hashable[CategoryId]): """Represents a Discord channel category. These are useful to group channels to logical compartments. @@ -2800,7 +2815,7 @@ def __init__( self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload ) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id = CategoryId(int(data["id"])) self._update(guild, data) def __repr__(self) -> str: @@ -2810,7 +2825,9 @@ def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: self.guild: Guild = guild # apparently this can be nullable in the case of a bad api deploy self.name: str = data.get("name") or "" - self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.category_id: Optional[CategoryId] = utils._get_as_snowflake( + data, "parent_id", CategoryId + ) self._flags = data.get("flags", 0) self.nsfw: bool = data.get("nsfw", False) self.position: int = data["position"] @@ -2911,8 +2928,7 @@ async def edit( *, position: int, reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def edit( @@ -2924,8 +2940,7 @@ async def edit( overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., flags: ChannelFlags = ..., reason: Optional[str] = ..., - ) -> CategoryChannel: - ... + ) -> CategoryChannel: ... async def edit( self, @@ -3011,8 +3026,7 @@ async def move( offset: int = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -3022,8 +3036,7 @@ async def move( offset: int = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -3033,8 +3046,7 @@ async def move( offset: int = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @overload async def move( @@ -3044,8 +3056,7 @@ async def move( offset: int = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... @utils.copy_doc(disnake.abc.GuildChannel.move) async def move(self, **kwargs: Any) -> None: @@ -3224,7 +3235,7 @@ class ThreadWithMessage(NamedTuple): message: Message -class ThreadOnlyGuildChannel(disnake.abc.GuildChannel, Hashable): +class ThreadOnlyGuildChannel(disnake.abc.GuildChannel, Hashable[ChannelId]): __slots__ = ( "id", "name", @@ -3255,7 +3266,7 @@ def __init__( data: Union[ForumChannelPayload, MediaChannelPayload], ) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id = ChannelId(int(data["id"])) self._type: int = data["type"] self._update(guild, data) @@ -3277,12 +3288,16 @@ def _update(self, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPay self.guild: Guild = guild # apparently this can be nullable in the case of a bad api deploy self.name: str = data.get("name") or "" - self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.category_id: Optional[CategoryId] = utils._get_as_snowflake( + data, "parent_id", CategoryId + ) self.topic: Optional[str] = data.get("topic") self.position: int = data["position"] self._flags = data.get("flags", 0) self.nsfw: bool = data.get("nsfw", False) - self.last_thread_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") + self.last_thread_id: Optional[ThreadId] = utils._get_as_snowflake( + data, "last_message_id", ThreadId + ) self.default_auto_archive_duration: ThreadArchiveDurationLiteral = data.get( "default_auto_archive_duration", 1440 ) @@ -3297,8 +3312,8 @@ def _update(self, guild: Guild, data: Union[ForumChannelPayload, MediaChannelPay default_reaction_emoji = data.get("default_reaction_emoji") or {} # emoji_id may be `0`, use `None` instead - self._default_reaction_emoji_id: Optional[int] = ( - utils._get_as_snowflake(default_reaction_emoji, "emoji_id") or None + self._default_reaction_emoji_id: Optional[EmojiId] = ( + utils._get_as_snowflake(default_reaction_emoji, "emoji_id", EmojiId) or None ) self._default_reaction_emoji_name: Optional[str] = default_reaction_emoji.get("emoji_name") @@ -3421,7 +3436,8 @@ async def trigger_typing(self) -> None: def typing(self) -> Typing: return Typing(self) - def get_thread(self, thread_id: int, /) -> Optional[Thread]: + @overload_get + def get_thread(self, thread_id: ThreadId, /) -> Optional[Thread]: """Returns a thread with the given ID. Parameters @@ -3454,8 +3470,7 @@ async def create_thread( view: View = ..., components: Components = ..., reason: Optional[str] = None, - ) -> ThreadWithMessage: - ... + ) -> ThreadWithMessage: ... @overload async def create_thread( @@ -3475,8 +3490,7 @@ async def create_thread( view: View = ..., components: Components = ..., reason: Optional[str] = None, - ) -> ThreadWithMessage: - ... + ) -> ThreadWithMessage: ... @overload async def create_thread( @@ -3496,8 +3510,7 @@ async def create_thread( view: View = ..., components: Components = ..., reason: Optional[str] = None, - ) -> ThreadWithMessage: - ... + ) -> ThreadWithMessage: ... @overload async def create_thread( @@ -3517,8 +3530,7 @@ async def create_thread( view: View = ..., components: Components = ..., reason: Optional[str] = None, - ) -> ThreadWithMessage: - ... + ) -> ThreadWithMessage: ... async def create_thread( self, @@ -3948,11 +3960,10 @@ async def edit( self, *, position: int, - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... # only passing `sync_permissions` may or may not return a channel, # depending on whether the channel is in a category @@ -3962,8 +3973,7 @@ async def edit( *, sync_permissions: bool, reason: Optional[str] = ..., - ) -> Optional[ForumChannel]: - ... + ) -> Optional[ForumChannel]: ... @overload async def edit( @@ -3974,7 +3984,7 @@ async def edit( position: int = ..., nsfw: bool = ..., sync_permissions: bool = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., slowmode_delay: Optional[int] = ..., default_thread_slowmode_delay: Optional[int] = ..., default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = ..., @@ -3986,8 +3996,7 @@ async def edit( default_sort_order: Optional[ThreadSortOrder] = ..., default_layout: ThreadLayout = ..., reason: Optional[str] = ..., - ) -> ForumChannel: - ... + ) -> ForumChannel: ... async def edit( self, @@ -3997,7 +4006,7 @@ async def edit( position: int = MISSING, nsfw: bool = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -4147,7 +4156,7 @@ async def clone( topic: Optional[str] = MISSING, position: int = MISSING, nsfw: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -4378,11 +4387,10 @@ async def edit( self, *, position: int, - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., sync_permissions: bool = ..., reason: Optional[str] = ..., - ) -> None: - ... + ) -> None: ... # only passing `sync_permissions` may or may not return a channel, # depending on whether the channel is in a category @@ -4392,8 +4400,7 @@ async def edit( *, sync_permissions: bool, reason: Optional[str] = ..., - ) -> Optional[MediaChannel]: - ... + ) -> Optional[MediaChannel]: ... @overload async def edit( @@ -4404,7 +4411,7 @@ async def edit( position: int = ..., nsfw: bool = ..., sync_permissions: bool = ..., - category: Optional[Snowflake] = ..., + category: Optional[Snowflake[CategoryId]] = ..., slowmode_delay: Optional[int] = ..., default_thread_slowmode_delay: Optional[int] = ..., default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = ..., @@ -4415,8 +4422,7 @@ async def edit( default_reaction: Optional[Union[str, Emoji, PartialEmoji]] = ..., default_sort_order: Optional[ThreadSortOrder] = ..., reason: Optional[str] = ..., - ) -> MediaChannel: - ... + ) -> MediaChannel: ... async def edit( self, @@ -4426,7 +4432,7 @@ async def edit( position: int = MISSING, nsfw: bool = MISSING, sync_permissions: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -4548,7 +4554,7 @@ async def clone( topic: Optional[str] = MISSING, position: int = MISSING, nsfw: bool = MISSING, - category: Optional[Snowflake] = MISSING, + category: Optional[Snowflake[CategoryId]] = MISSING, slowmode_delay: Optional[int] = MISSING, default_thread_slowmode_delay: Optional[int] = MISSING, default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING, @@ -4663,7 +4669,7 @@ async def clone( ) -class DMChannel(disnake.abc.Messageable, Hashable): +class DMChannel(disnake.abc.Messageable, Hashable[PrivateChannelId]): """Represents a Discord direct message channel. .. collapse:: operations @@ -4716,7 +4722,7 @@ def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPay self.recipient = state.store_user(recipients[0]) # type: ignore self.me: ClientUser = me - self.id: int = int(data["id"]) + self.id = PrivateChannelId(int(data["id"])) self.last_pin_timestamp: Optional[datetime.datetime] = utils.parse_time( data.get("last_pin_timestamp") ) @@ -4734,7 +4740,9 @@ def __repr__(self) -> str: return f"" @classmethod - def _from_message(cls, state: ConnectionState, channel_id: int, user_id: int) -> Self: + def _from_message( + cls, state: ConnectionState, channel_id: PrivateChannelId, user_id: UserId + ) -> Self: self = cls.__new__(cls) self._state = state self.id = channel_id @@ -4806,7 +4814,8 @@ def permissions_for( """ return Permissions.private_channel() - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the given message ID. This is useful if you want to work with a message and only have its ID without @@ -4829,7 +4838,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: return PartialMessage(channel=self, id=message_id) -class GroupChannel(disnake.abc.Messageable, Hashable): +class GroupChannel(disnake.abc.Messageable, Hashable[PrivateChannelId]): """Represents a Discord group channel. .. collapse:: operations @@ -4877,12 +4886,12 @@ def __init__( self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload ) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id = PrivateChannelId(int(data["id"])) self.me: ClientUser = me self._update_group(data) def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id") + self.owner_id: Optional[UserId] = utils._get_as_snowflake(data, "owner_id", UserId) self._icon: Optional[str] = data.get("icon") self.name: Optional[str] = data.get("name") self.recipients: List[User] = [ @@ -4932,7 +4941,7 @@ def created_at(self) -> datetime.datetime: def permissions_for( self, - obj: Snowflake, + obj: Snowflake[UserId], /, *, ignore_timeout: bool = MISSING, @@ -4983,7 +4992,7 @@ async def leave(self) -> None: await self._state.http.leave_group(self.id) -class PartialMessageable(disnake.abc.Messageable, Hashable): +class PartialMessageable(disnake.abc.Messageable, Hashable[ChannelId]): """Represents a partial messageable to aid with working messageable channels when only a channel ID is present. @@ -5015,15 +5024,18 @@ class PartialMessageable(disnake.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None) -> None: + def __init__( + self, state: ConnectionState, id: ChannelId, type: Optional[ChannelType] = None + ) -> None: self._state: ConnectionState = state - self.id: int = id + self.id = id self.type: Optional[ChannelType] = type async def _get_channel(self) -> PartialMessageable: return self - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the given message ID. This is useful if you want to work with a message and only have its ID without @@ -5089,7 +5101,7 @@ def _threaded_guild_channel_factory(channel_type: int): def _channel_type_factory( - cls: Union[Type[disnake.abc.GuildChannel], Type[Thread]] + cls: Union[Type[disnake.abc.GuildChannel], Type[Thread]], ) -> List[ChannelType]: return { disnake.abc.GuildChannel: list(ChannelType.__members__.values()), diff --git a/disnake/client.py b/disnake/client.py index 80b3d67c65..3a2bc9b397 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -73,6 +73,23 @@ from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .template import Template from .threads import Thread +from .types.ids import ( + ApplicationCommandId, + ApplicationId, + ChannelId, + EmojiId, + GuildId, + MessageId, + ObjectId, + PrivateChannelId, + StickerId, + ThreadId, + UserId, + WebhookId, + overload_fetch, + overload_get, + overload_get_seq, +) from .ui.view import View from .user import ClientUser, User from .utils import MISSING, deprecated @@ -82,7 +99,7 @@ from .widget import Widget if TYPE_CHECKING: - from typing_extensions import NotRequired + from typing_extensions import Never, NotRequired from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand, MessageCommand, SlashCommand, UserCommand @@ -392,7 +409,7 @@ def __init__( proxy_auth: Optional[aiohttp.BasicAuth] = None, assume_unsync_clock: bool = True, max_messages: Optional[int] = 1000, - application_id: Optional[int] = None, + application_id: Optional[ApplicationId] = None, heartbeat_timeout: float = 60.0, guild_ready_timeout: float = 2.0, allowed_mentions: Optional[AllowedMentions] = None, @@ -481,7 +498,7 @@ def __init__( # internals def _get_websocket( - self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None + self, guild_id: Optional[GuildId] = None, *, shard_id: Optional[int] = None ) -> DiscordWebSocket: return self.ws @@ -489,7 +506,7 @@ def _get_state( self, *, max_messages: Optional[int], - application_id: Optional[int], + application_id: Optional[ApplicationId], heartbeat_timeout: float, guild_ready_timeout: float, allowed_mentions: Optional[AllowedMentions], @@ -650,7 +667,8 @@ def global_message_commands(self) -> List[APIMessageCommand]: if isinstance(cmd, APIMessageCommand) ] - def get_message(self, id: int) -> Optional[Message]: + @overload_get + def get_message(self, id: MessageId) -> Optional[Message]: """Gets the message with the given ID from the bot's message cache. Parameters @@ -665,17 +683,28 @@ def get_message(self, id: int) -> Optional[Message]: """ return utils.get(self.cached_messages, id=id) + @overload + async def get_or_fetch_user( + self, user_id: UserId, *, strict: Literal[False] = ... + ) -> Optional[User]: ... + @overload + async def get_or_fetch_user( + self, user_id: ObjectId, *, strict: Literal[False] = ... + ) -> None: ... @overload async def get_or_fetch_user( self, user_id: int, *, strict: Literal[False] = ... - ) -> Optional[User]: - ... - + ) -> Optional[User]: ... + @overload + async def get_or_fetch_user(self, user_id: UserId, *, strict: Literal[True]) -> User: ... + @overload + async def get_or_fetch_user(self, user_id: ObjectId, *, strict: Literal[True]) -> Never: ... @overload - async def get_or_fetch_user(self, user_id: int, *, strict: Literal[True]) -> User: - ... + async def get_or_fetch_user(self, user_id: int, *, strict: Literal[True]) -> User: ... - async def get_or_fetch_user(self, user_id: int, *, strict: bool = False) -> Optional[User]: + async def get_or_fetch_user( + self, user_id: Union[ObjectId, int], *, strict: bool = False + ) -> Optional[User]: """|coro| Tries to get the user from the cache. If it fails, @@ -1383,7 +1412,19 @@ def users(self) -> List[User]: """List[:class:`~disnake.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + @overload + def get_channel(self, id: ChannelId, /) -> Optional[GuildChannel]: ... + @overload + def get_channel(self, id: ThreadId, /) -> Optional[Thread]: ... + @overload + def get_channel(self, id: PrivateChannelId, /) -> Optional[PrivateChannel]: ... + @overload + def get_channel(self, id: ObjectId, /) -> None: ... + @overload + def get_channel(self, id: int, /) -> Union[GuildChannel, Thread, PrivateChannel, None]: ... + def get_channel( + self, id: Union[ObjectId, int], / + ) -> Union[GuildChannel, Thread, PrivateChannel, None]: """Returns a channel or thread with the given ID. Parameters @@ -1396,10 +1437,23 @@ def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, Privat Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] The returned channel or ``None`` if not found. """ - return self._connection.get_channel(id) + return self._connection.get_channel(id) # type: ignore + # invalid ID still results in PartialMessageable, but it's effectively useless + @overload + def get_partial_messageable( + self, id: ChannelId, *, type: Optional[ChannelType] = None + ) -> PartialMessageable: ... + @overload + def get_partial_messageable( + self, id: ObjectId, *, type: Optional[ChannelType] = None + ) -> Never: ... + @overload def get_partial_messageable( self, id: int, *, type: Optional[ChannelType] = None + ) -> PartialMessageable: ... + def get_partial_messageable( + self, id: Union[ChannelId, int], *, type: Optional[ChannelType] = None ) -> PartialMessageable: """Returns a partial messageable with the given channel ID. @@ -1420,9 +1474,10 @@ def get_partial_messageable( :class:`.PartialMessageable` The partial messageable """ - return PartialMessageable(state=self._connection, id=id, type=type) + return PartialMessageable(state=self._connection, id=ChannelId(id), type=type) - def get_stage_instance(self, id: int, /) -> Optional[StageInstance]: + @overload_get + def get_stage_instance(self, id: ChannelId, /) -> Optional[StageInstance]: """Returns a stage instance with the given stage channel ID. .. versionadded:: 2.0 @@ -1444,7 +1499,8 @@ def get_stage_instance(self, id: int, /) -> Optional[StageInstance]: if isinstance(channel, StageChannel): return channel.instance - def get_guild(self, id: int, /) -> Optional[Guild]: + @overload_get + def get_guild(self, id: GuildId, /) -> Optional[Guild]: """Returns a guild with the given ID. Parameters @@ -1459,7 +1515,8 @@ def get_guild(self, id: int, /) -> Optional[Guild]: """ return self._connection._get_guild(id) - def get_user(self, id: int, /) -> Optional[User]: + @overload_get + def get_user(self, id: UserId, /) -> Optional[User]: """Returns a user with the given ID. Parameters @@ -1474,7 +1531,8 @@ def get_user(self, id: int, /) -> Optional[User]: """ return self._connection.get_user(id) - def get_emoji(self, id: int, /) -> Optional[Emoji]: + @overload_get + def get_emoji(self, id: EmojiId, /) -> Optional[Emoji]: """Returns an emoji with the given ID. Parameters @@ -1489,7 +1547,8 @@ def get_emoji(self, id: int, /) -> Optional[Emoji]: """ return self._connection.get_emoji(id) - def get_sticker(self, id: int, /) -> Optional[GuildSticker]: + @overload_get + def get_sticker(self, id: StickerId, /) -> Optional[GuildSticker]: """Returns a guild sticker with the given ID. .. versionadded:: 2.0 @@ -1546,7 +1605,8 @@ def get_all_members(self) -> Generator[Member, None, None]: for guild in self.guilds: yield from guild.members - def get_guild_application_commands(self, guild_id: int) -> List[APIApplicationCommand]: + @overload_get_seq + def get_guild_application_commands(self, guild_id: GuildId) -> List[APIApplicationCommand]: """Returns a list of all application commands in the guild with the given ID. Parameters @@ -1562,7 +1622,8 @@ def get_guild_application_commands(self, guild_id: int) -> List[APIApplicationCo data = self._connection._guild_application_commands.get(guild_id, {}) return list(data.values()) - def get_guild_slash_commands(self, guild_id: int) -> List[APISlashCommand]: + @overload_get_seq + def get_guild_slash_commands(self, guild_id: GuildId) -> List[APISlashCommand]: """Returns a list of all slash commands in the guild with the given ID. Parameters @@ -1578,7 +1639,8 @@ def get_guild_slash_commands(self, guild_id: int) -> List[APISlashCommand]: data = self._connection._guild_application_commands.get(guild_id, {}) return [cmd for cmd in data.values() if isinstance(cmd, APISlashCommand)] - def get_guild_user_commands(self, guild_id: int) -> List[APIUserCommand]: + @overload_get_seq + def get_guild_user_commands(self, guild_id: GuildId) -> List[APIUserCommand]: """Returns a list of all user commands in the guild with the given ID. Parameters @@ -1594,7 +1656,8 @@ def get_guild_user_commands(self, guild_id: int) -> List[APIUserCommand]: data = self._connection._guild_application_commands.get(guild_id, {}) return [cmd for cmd in data.values() if isinstance(cmd, APIUserCommand)] - def get_guild_message_commands(self, guild_id: int) -> List[APIMessageCommand]: + @overload_get_seq + def get_guild_message_commands(self, guild_id: GuildId) -> List[APIMessageCommand]: """Returns a list of all message commands in the guild with the given ID. Parameters @@ -1610,7 +1673,8 @@ def get_guild_message_commands(self, guild_id: int) -> List[APIMessageCommand]: data = self._connection._guild_application_commands.get(guild_id, {}) return [cmd for cmd in data.values() if isinstance(cmd, APIMessageCommand)] - def get_global_command(self, id: int) -> Optional[APIApplicationCommand]: + @overload_get + def get_global_command(self, id: ApplicationCommandId) -> Optional[APIApplicationCommand]: """Returns a global application command with the given ID. Parameters @@ -1625,7 +1689,9 @@ def get_global_command(self, id: int) -> Optional[APIApplicationCommand]: """ return self._connection._get_global_application_command(id) - def get_guild_command(self, guild_id: int, id: int) -> Optional[APIApplicationCommand]: + def get_guild_command( + self, guild_id: GuildId, id: ApplicationCommandId + ) -> Optional[APIApplicationCommand]: """Returns a guild application command with the given guild ID and application command ID. Parameters @@ -1661,8 +1727,9 @@ def get_global_command_named( """ return self._connection._get_global_command_named(name, cmd_type) + @overload_get def get_guild_command_named( - self, guild_id: int, name: str, cmd_type: Optional[ApplicationCommandType] = None + self, guild_id: GuildId, name: str, cmd_type: Optional[ApplicationCommandType] = None ) -> Optional[APIApplicationCommand]: """Returns a guild application command matching the given name. @@ -1998,7 +2065,8 @@ async def fetch_template(self, code: Union[Template, str]) -> Template: data = await self.http.get_template(code) return Template(data=data, state=self._connection) - async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild: + @overload_fetch + async def fetch_guild(self, guild_id: GuildId, /, *, with_counts: bool = True) -> Guild: """|coro| Retrieves a :class:`.Guild` from the given ID. @@ -2037,9 +2105,10 @@ async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Gu data = await self.http.get_guild(guild_id, with_counts=with_counts) return Guild(data=data, state=self._connection) + @overload_fetch async def fetch_guild_preview( self, - guild_id: int, + guild_id: GuildId, /, ) -> GuildPreview: """|coro| @@ -2167,7 +2236,8 @@ def guild_builder(self, name: str) -> GuildBuilder: """ return GuildBuilder(name=name, state=self._connection) - async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance: + @overload_fetch + async def fetch_stage_instance(self, channel_id: ChannelId, /) -> StageInstance: """|coro| Retrieves a :class:`.StageInstance` with the given ID. @@ -2296,7 +2366,17 @@ async def delete_invite(self, invite: Union[Invite, str]) -> None: # Voice region stuff - async def fetch_voice_regions(self, guild_id: Optional[int] = None) -> List[VoiceRegion]: + @overload + async def fetch_voice_regions( + self, guild_id: Optional[GuildId] = None + ) -> List[VoiceRegion]: ... + @overload + async def fetch_voice_regions(self, guild_id: ObjectId) -> Never: ... + @overload + async def fetch_voice_regions(self, guild_id: int) -> List[VoiceRegion]: ... + async def fetch_voice_regions( + self, guild_id: Union[ObjectId, int, None] = None + ) -> List[VoiceRegion]: """Retrieves a list of :class:`.VoiceRegion`\\s. Retrieves voice regions for the user, or a guild if provided. @@ -2323,7 +2403,8 @@ async def fetch_voice_regions(self, guild_id: Optional[int] = None) -> List[Voic # Miscellaneous stuff - async def fetch_widget(self, guild_id: int, /) -> Widget: + @overload_fetch + async def fetch_widget(self, guild_id: GuildId, /) -> Widget: """|coro| Retrieves a :class:`.Widget` for the given guild ID. @@ -2372,7 +2453,8 @@ async def application_info(self) -> AppInfo: data["rpc_origins"] = None return AppInfo(self._connection, data) - async def fetch_user(self, user_id: int, /) -> User: + @overload_fetch + async def fetch_user(self, user_id: UserId, /) -> User: """|coro| Retrieves a :class:`~disnake.User` based on their ID. @@ -2403,10 +2485,21 @@ async def fetch_user(self, user_id: int, /) -> User: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) + @overload + async def fetch_channel(self, channel_id: ChannelId, /) -> GuildChannel: ... + @overload + async def fetch_channel(self, channel_id: PrivateChannelId, /) -> PrivateChannel: ... + @overload + async def fetch_channel(self, channel_id: ThreadId, /) -> Thread: ... + @overload + async def fetch_channel(self, channel_id: ObjectId, /) -> Never: ... + @overload async def fetch_channel( - self, - channel_id: int, - /, + self, channel_id: int, / + ) -> Union[GuildChannel, PrivateChannel, Thread]: ... + + async def fetch_channel( + self, channel_id: Union[ObjectId, int], / ) -> Union[GuildChannel, PrivateChannel, Thread]: """|coro| @@ -2457,7 +2550,8 @@ async def fetch_channel( return channel - async def fetch_webhook(self, webhook_id: int, /) -> Webhook: + @overload_fetch + async def fetch_webhook(self, webhook_id: WebhookId, /) -> Webhook: """|coro| Retrieves a :class:`.Webhook` with the given ID. @@ -2484,7 +2578,8 @@ async def fetch_webhook(self, webhook_id: int, /) -> Webhook: data = await self.http.get_webhook(webhook_id) return Webhook.from_state(data, state=self._connection) - async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]: + @overload_fetch + async def fetch_sticker(self, sticker_id: StickerId, /) -> Union[StandardSticker, GuildSticker]: """|coro| Retrieves a :class:`.Sticker` with the given ID. @@ -2570,7 +2665,7 @@ async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """ return await self.fetch_sticker_packs() - async def create_dm(self, user: Snowflake) -> DMChannel: + async def create_dm(self, user: Snowflake[UserId]) -> DMChannel: """|coro| Creates a :class:`.DMChannel` with the given user. @@ -2598,7 +2693,7 @@ async def create_dm(self, user: Snowflake) -> DMChannel: data = await state.http.start_private_message(user.id) return state.add_dm_channel(data) - def add_view(self, view: View, *, message_id: Optional[int] = None) -> None: + def add_view(self, view: View, *, message_id: Optional[MessageId] = None) -> None: """Registers a :class:`~disnake.ui.View` for persistent listening. This method should be used for when a view is comprised of components @@ -2668,7 +2763,8 @@ async def fetch_global_commands( """ return await self._connection.fetch_global_commands(with_localizations=with_localizations) - async def fetch_global_command(self, command_id: int) -> APIApplicationCommand: + @overload_fetch + async def fetch_global_command(self, command_id: ApplicationCommandId) -> APIApplicationCommand: """|coro| Retrieves a global application command. @@ -2688,22 +2784,20 @@ async def fetch_global_command(self, command_id: int) -> APIApplicationCommand: return await self._connection.fetch_global_command(command_id) @overload - async def create_global_command(self, application_command: SlashCommand) -> APISlashCommand: - ... + async def create_global_command(self, application_command: SlashCommand) -> APISlashCommand: ... @overload - async def create_global_command(self, application_command: UserCommand) -> APIUserCommand: - ... + async def create_global_command(self, application_command: UserCommand) -> APIUserCommand: ... @overload - async def create_global_command(self, application_command: MessageCommand) -> APIMessageCommand: - ... + async def create_global_command( + self, application_command: MessageCommand + ) -> APIMessageCommand: ... @overload async def create_global_command( self, application_command: ApplicationCommand - ) -> APIApplicationCommand: - ... + ) -> APIApplicationCommand: ... async def create_global_command( self, application_command: ApplicationCommand @@ -2729,30 +2823,26 @@ async def create_global_command( @overload async def edit_global_command( - self, command_id: int, new_command: SlashCommand - ) -> APISlashCommand: - ... + self, command_id: ApplicationCommandId, new_command: SlashCommand + ) -> APISlashCommand: ... @overload async def edit_global_command( - self, command_id: int, new_command: UserCommand - ) -> APIUserCommand: - ... + self, command_id: ApplicationCommandId, new_command: UserCommand + ) -> APIUserCommand: ... @overload async def edit_global_command( - self, command_id: int, new_command: MessageCommand - ) -> APIMessageCommand: - ... + self, command_id: ApplicationCommandId, new_command: MessageCommand + ) -> APIMessageCommand: ... @overload async def edit_global_command( - self, command_id: int, new_command: ApplicationCommand - ) -> APIApplicationCommand: - ... + self, command_id: ApplicationCommandId, new_command: ApplicationCommand + ) -> APIApplicationCommand: ... async def edit_global_command( - self, command_id: int, new_command: ApplicationCommand + self, command_id: ApplicationCommandId, new_command: ApplicationCommand ) -> APIApplicationCommand: """|coro| @@ -2775,7 +2865,7 @@ async def edit_global_command( new_command.localize(self.i18n) return await self._connection.edit_global_command(command_id, new_command) - async def delete_global_command(self, command_id: int) -> None: + async def delete_global_command(self, command_id: ApplicationCommandId) -> None: """|coro| Deletes a global application command. @@ -2814,9 +2904,10 @@ async def bulk_overwrite_global_commands( # Application commands (guild) + @overload_fetch async def fetch_guild_commands( self, - guild_id: int, + guild_id: GuildId, *, with_localizations: bool = True, ) -> List[APIApplicationCommand]: @@ -2844,7 +2935,9 @@ async def fetch_guild_commands( guild_id, with_localizations=with_localizations ) - async def fetch_guild_command(self, guild_id: int, command_id: int) -> APIApplicationCommand: + async def fetch_guild_command( + self, guild_id: GuildId, command_id: ApplicationCommandId + ) -> APIApplicationCommand: """|coro| Retrieves a guild application command. @@ -2867,30 +2960,26 @@ async def fetch_guild_command(self, guild_id: int, command_id: int) -> APIApplic @overload async def create_guild_command( - self, guild_id: int, application_command: SlashCommand - ) -> APISlashCommand: - ... + self, guild_id: GuildId, application_command: SlashCommand + ) -> APISlashCommand: ... @overload async def create_guild_command( - self, guild_id: int, application_command: UserCommand - ) -> APIUserCommand: - ... + self, guild_id: GuildId, application_command: UserCommand + ) -> APIUserCommand: ... @overload async def create_guild_command( - self, guild_id: int, application_command: MessageCommand - ) -> APIMessageCommand: - ... + self, guild_id: GuildId, application_command: MessageCommand + ) -> APIMessageCommand: ... @overload async def create_guild_command( - self, guild_id: int, application_command: ApplicationCommand - ) -> APIApplicationCommand: - ... + self, guild_id: GuildId, application_command: ApplicationCommand + ) -> APIApplicationCommand: ... async def create_guild_command( - self, guild_id: int, application_command: ApplicationCommand + self, guild_id: GuildId, application_command: ApplicationCommand ) -> APIApplicationCommand: """|coro| @@ -2915,30 +3004,26 @@ async def create_guild_command( @overload async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: SlashCommand - ) -> APISlashCommand: - ... + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: SlashCommand + ) -> APISlashCommand: ... @overload async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: UserCommand - ) -> APIUserCommand: - ... + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: UserCommand + ) -> APIUserCommand: ... @overload async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: MessageCommand - ) -> APIMessageCommand: - ... + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: MessageCommand + ) -> APIMessageCommand: ... @overload async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: ApplicationCommand - ) -> APIApplicationCommand: - ... + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: ApplicationCommand + ) -> APIApplicationCommand: ... async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: ApplicationCommand + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: ApplicationCommand ) -> APIApplicationCommand: """|coro| @@ -2963,7 +3048,9 @@ async def edit_guild_command( new_command.localize(self.i18n) return await self._connection.edit_guild_command(guild_id, command_id, new_command) - async def delete_guild_command(self, guild_id: int, command_id: int) -> None: + async def delete_guild_command( + self, guild_id: GuildId, command_id: ApplicationCommandId + ) -> None: """|coro| Deletes a guild application command. @@ -2980,7 +3067,7 @@ async def delete_guild_command(self, guild_id: int, command_id: int) -> None: await self._connection.delete_guild_command(guild_id, command_id) async def bulk_overwrite_guild_commands( - self, guild_id: int, application_commands: List[ApplicationCommand] + self, guild_id: GuildId, application_commands: List[ApplicationCommand] ) -> List[APIApplicationCommand]: """|coro| @@ -3007,7 +3094,7 @@ async def bulk_overwrite_guild_commands( # Application command permissions async def bulk_fetch_command_permissions( - self, guild_id: int + self, guild_id: GuildId ) -> List[GuildApplicationCommandPermissions]: """|coro| @@ -3023,7 +3110,7 @@ async def bulk_fetch_command_permissions( return await self._connection.bulk_fetch_command_permissions(guild_id) async def fetch_command_permissions( - self, guild_id: int, command_id: int + self, guild_id: GuildId, command_id: ApplicationCommandId ) -> GuildApplicationCommandPermissions: """|coro| @@ -3138,8 +3225,8 @@ def entitlements( limit: Optional[int] = 100, before: Optional[SnowflakeTime] = None, after: Optional[SnowflakeTime] = None, - user: Optional[Snowflake] = None, - guild: Optional[Snowflake] = None, + user: Optional[Snowflake[UserId]] = None, + guild: Optional[Snowflake[GuildId]] = None, skus: Optional[Sequence[Snowflake]] = None, exclude_ended: bool = False, oldest_first: bool = False, diff --git a/disnake/colour.py b/disnake/colour.py index 4bd6585ea2..e49a88cde1 100644 --- a/disnake/colour.py +++ b/disnake/colour.py @@ -61,10 +61,10 @@ def __init__(self, value: int) -> None: def _get_byte(self, byte: int) -> int: return (self.value >> (8 * byte)) & 0xFF - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, Colour) and self.value == other.value - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __str__(self) -> str: diff --git a/disnake/embeds.py b/disnake/embeds.py index 1866d8d7eb..f1e4a69abb 100644 --- a/disnake/embeds.py +++ b/disnake/embeds.py @@ -55,7 +55,7 @@ def __repr__(self) -> str: def __getattr__(self, attr: str) -> None: return None - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, EmbedProxy) and self.__dict__ == other.__dict__ @@ -322,7 +322,7 @@ def __bool__(self) -> bool: ) ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, Embed): return False for slot in self.__slots__: diff --git a/disnake/emoji.py b/disnake/emoji.py index badedbce86..788e2de46b 100644 --- a/disnake/emoji.py +++ b/disnake/emoji.py @@ -5,7 +5,9 @@ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union from .asset import Asset, AssetMixin +from .mixins import Hashable from .partial_emoji import PartialEmoji, _EmojiTag +from .types.ids import EmojiId, GuildId, RoleId from .user import User from .utils import MISSING, SnowflakeList, snowflake_time @@ -22,7 +24,7 @@ from .types.emoji import Emoji as EmojiPayload -class Emoji(_EmojiTag, AssetMixin): +class Emoji(_EmojiTag, AssetMixin, Hashable[EmojiId]): """Represents a custom emoji. Depending on the way this object was created, some of the attributes can @@ -88,18 +90,18 @@ class Emoji(_EmojiTag, AssetMixin): def __init__( self, *, guild: Union[Guild, GuildPreview], state: ConnectionState, data: EmojiPayload ) -> None: - self.guild_id: int = guild.id + self.guild_id: GuildId = guild.id self._state: ConnectionState = state self._from_data(data) def _from_data(self, emoji: EmojiPayload) -> None: self.require_colons: bool = emoji.get("require_colons", False) self.managed: bool = emoji.get("managed", False) - self.id: int = int(emoji["id"]) # type: ignore + self.id = EmojiId(int(emoji["id"])) # type: ignore self.name: str = emoji["name"] # type: ignore self.animated: bool = emoji.get("animated", False) self.available: bool = emoji.get("available", True) - self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", []))) + self._roles: SnowflakeList[RoleId] = SnowflakeList(map(int, emoji.get("roles", []))) # type: ignore user = emoji.get("user") self.user: Optional[User] = User(state=self._state, data=user) if user else None @@ -121,14 +123,11 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _EmojiTag) and self.id == other.id - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __hash__(self) -> int: - return self.id >> 22 + def __ne__(self, other: object) -> bool: + return not isinstance(other, _EmojiTag) or self.id != other.id @property def created_at(self) -> datetime: @@ -199,7 +198,11 @@ async def delete(self, *, reason: Optional[str] = None) -> None: await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) async def edit( - self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None + self, + *, + name: str = MISSING, + roles: List[Snowflake[RoleId]] = MISSING, + reason: Optional[str] = None, ) -> Emoji: """|coro| diff --git a/disnake/guild.py b/disnake/guild.py index 97ea1e80ac..f6311272d6 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -77,6 +77,22 @@ from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember +from .types.ids import ( + ApplicationCommandId, + ApplicationId, + CategoryId, + ChannelId, + EmojiId, + GuildId, + MemberId, + ObjectId, + RoleId, + StickerId, + ThreadId, + UserId, + overload_fetch, + overload_get, +) from .user import User from .voice_region import VoiceRegion from .welcome_screen import WelcomeScreen, WelcomeScreenChannel @@ -91,6 +107,8 @@ MISSING = utils.MISSING if TYPE_CHECKING: + from typing_extensions import Never + from .abc import Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand from .asset import AssetBytes @@ -130,7 +148,7 @@ class _GuildLimit(NamedTuple): filesize: int -class Guild(Hashable): +class Guild(Hashable[GuildId]): """Represents a Discord guild. This is referred to as a "server" in the official Discord UI. @@ -371,10 +389,10 @@ class Guild(Hashable): } def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None: - self._channels: Dict[int, GuildChannel] = {} - self._members: Dict[int, Member] = {} - self._voice_states: Dict[int, VoiceState] = {} - self._threads: Dict[int, Thread] = {} + self._channels: Dict[ChannelId, GuildChannel] = {} + self._members: Dict[MemberId, Member] = {} + self._voice_states: Dict[MemberId, VoiceState] = {} + self._threads: Dict[ThreadId, Thread] = {} self._stage_instances: Dict[int, StageInstance] = {} self._scheduled_events: Dict[int, GuildScheduledEvent] = {} self._state: ConnectionState = state @@ -383,10 +401,10 @@ def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None: def _add_channel(self, channel: GuildChannel, /) -> None: self._channels[channel.id] = channel - def _remove_channel(self, channel: Snowflake, /) -> None: + def _remove_channel(self, channel: Snowflake[ChannelId], /) -> None: self._channels.pop(channel.id, None) - def _voice_state_for(self, user_id: int, /) -> Optional[VoiceState]: + def _voice_state_for(self, user_id: MemberId, /) -> Optional[VoiceState]: return self._voice_states.get(user_id) def _add_member(self, member: Member, /) -> None: @@ -397,25 +415,25 @@ def _store_thread(self, payload: ThreadPayload, /) -> Thread: self._threads[thread.id] = thread return thread - def _remove_member(self, member: Snowflake, /) -> None: + def _remove_member(self, member: Snowflake[MemberId], /) -> None: self._members.pop(member.id, None) def _add_thread(self, thread: Thread, /) -> None: self._threads[thread.id] = thread - def _remove_thread(self, thread: Snowflake, /) -> None: + def _remove_thread(self, thread: Snowflake[ThreadId], /) -> None: self._threads.pop(thread.id, None) def _clear_threads(self) -> None: self._threads.clear() - def _remove_threads_by_channel(self, channel_id: int) -> None: + def _remove_threads_by_channel(self, channel_id: ChannelId) -> None: to_remove = [k for k, t in self._threads.items() if t.parent_id == channel_id] for k in to_remove: del self._threads[k] - def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]: - to_remove: Dict[int, Thread] = { + def _filter_threads(self, channel_ids: Set[ChannelId]) -> Dict[ThreadId, Thread]: + to_remove: Dict[ThreadId, Thread] = { k: t for k, t in self._threads.items() if t.parent_id in channel_ids } for k in to_remove: @@ -437,9 +455,9 @@ def __repr__(self) -> str: return f"" def _update_voice_state( - self, data: GuildVoiceState, channel_id: Optional[int] + self, data: GuildVoiceState, channel_id: Optional[ChannelId] ) -> Tuple[Optional[Member], VoiceState, VoiceState]: - user_id = int(data["user_id"]) + user_id = MemberId(int(data["user_id"])) channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore try: # check if we should remove the voice state from cache @@ -476,7 +494,7 @@ def _add_role(self, role: Role, /) -> None: self._roles[role.id] = role - def _remove_role(self, role_id: int, /) -> Role: + def _remove_role(self, role_id: RoleId, /) -> Role: # this raises KeyError if it fails.. role = self._roles.pop(role_id) @@ -488,12 +506,13 @@ def _remove_role(self, role_id: int, /) -> Role: return role - def get_command(self, application_command_id: int, /) -> Optional[APIApplicationCommand]: + @overload_get + def get_command(self, id: ApplicationCommandId, /) -> Optional[APIApplicationCommand]: """Gets a cached application command matching the specified ID. Parameters ---------- - application_command_id: :class:`int` + id: :class:`int` The application command ID to search for. Returns @@ -501,7 +520,7 @@ def get_command(self, application_command_id: int, /) -> Optional[APIApplication Optional[Union[:class:`.APIUserCommand`, :class:`.APIMessageCommand`, :class:`.APISlashCommand`]] The application command if found, or ``None`` otherwise. """ - return self._state._get_guild_application_command(self.id, application_command_id) + return self._state._get_guild_application_command(self.id, id) def get_command_named(self, name: str, /) -> Optional[APIApplicationCommand]: """Gets a cached application command matching the specified name. @@ -540,8 +559,8 @@ def _from_data(self, guild: GuildPayload) -> None: self._icon: Optional[str] = guild.get("icon") self._banner: Optional[str] = guild.get("banner") self.unavailable: bool = guild.get("unavailable", False) - self.id: int = int(guild["id"]) - self._roles: Dict[int, Role] = {} + self.id = GuildId(int(guild["id"])) + self._roles: Dict[RoleId, Role] = {} state = self._state # speed up attribute access for r in guild.get("roles", []): role = Role(guild=self, data=r, state=state) @@ -556,7 +575,9 @@ def _from_data(self, guild: GuildPayload) -> None: ) self.features: List[GuildFeature] = guild.get("features", []) self._splash: Optional[str] = guild.get("splash") - self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id") + self._system_channel_id: Optional[ChannelId] = utils._get_as_snowflake( + guild, "system_channel_id", ChannelId + ) self.description: Optional[str] = guild.get("description") self.max_presences: Optional[int] = guild.get("max_presences") self.max_members: Optional[int] = guild.get("max_members") @@ -569,19 +590,23 @@ def _from_data(self, guild: GuildPayload) -> None: self._system_channel_flags: int = guild.get("system_channel_flags", 0) self.preferred_locale: Locale = try_enum(Locale, guild.get("preferred_locale")) self._discovery_splash: Optional[str] = guild.get("discovery_splash") - self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, "rules_channel_id") - self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake( - guild, "public_updates_channel_id" + self._rules_channel_id: Optional[ChannelId] = utils._get_as_snowflake( + guild, "rules_channel_id", ChannelId + ) + self._public_updates_channel_id: Optional[ChannelId] = utils._get_as_snowflake( + guild, "public_updates_channel_id", ChannelId ) self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self.premium_progress_bar_enabled: bool = guild.get("premium_progress_bar_enabled", False) self.approximate_presence_count: Optional[int] = guild.get("approximate_presence_count") self.approximate_member_count: Optional[int] = guild.get("approximate_member_count") self.widget_enabled: Optional[bool] = guild.get("widget_enabled") - self.widget_channel_id: Optional[int] = utils._get_as_snowflake(guild, "widget_channel_id") + self.widget_channel_id: Optional[ChannelId] = utils._get_as_snowflake( + guild, "widget_channel_id", ChannelId + ) self.vanity_url_code: Optional[str] = guild.get("vanity_url_code") - self._safety_alerts_channel_id: Optional[int] = utils._get_as_snowflake( - guild, "safety_alerts_channel_id" + self._safety_alerts_channel_id: Optional[ChannelId] = utils._get_as_snowflake( + guild, "safety_alerts_channel_id", ChannelId ) stage_instances = guild.get("stage_instances") @@ -609,11 +634,13 @@ def _from_data(self, guild: GuildPayload) -> None: self._sync(guild) self._large: Optional[bool] = None if member_count is None else self._member_count >= 250 - self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id") - self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore + self.owner_id: Optional[MemberId] = utils._get_as_snowflake(guild, "owner_id", MemberId) + self.afk_channel: Optional[VocalGuildChannel] = self.get_channel( + utils._get_as_snowflake(guild, "afk_channel_id") # type: ignore + ) for obj in guild.get("voice_states", []): - self._update_voice_state(obj, utils._get_as_snowflake(obj, "channel_id")) + self._update_voice_state(obj, utils._get_as_snowflake(obj, "channel_id", ChannelId)) # TODO: refactor/remove? def _sync(self, data: GuildPayload) -> None: @@ -783,13 +810,18 @@ def key(t: ByCategoryItem) -> Tuple[Tuple[int, int], List[GuildChannel]]: channels.sort(key=lambda c: (c._sorting_bucket, c.position, c.id)) return as_list - def _resolve_channel(self, id: Optional[int], /) -> Optional[Union[GuildChannel, Thread]]: + def _resolve_channel( + self, id: Union[ChannelId, ThreadId, None], / + ) -> Optional[Union[GuildChannel, Thread]]: if id is None: - return + return None - return self._channels.get(id) or self._threads.get(id) + return self._channels.get(id) or self._threads.get(id) # type: ignore - def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, GuildChannel]]: + @overload_get + def get_channel_or_thread( + self, channel_id: Union[ChannelId, ThreadId], / + ) -> Union[Thread, GuildChannel, None]: """Returns a channel or thread with the given ID. .. versionadded:: 2.0 @@ -804,9 +836,10 @@ def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, Gu Optional[Union[:class:`Thread`, :class:`.abc.GuildChannel`]] The returned channel or thread or ``None`` if not found. """ - return self._channels.get(channel_id) or self._threads.get(channel_id) + return self._channels.get(channel_id) or self._threads.get(channel_id) # type: ignore - def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]: + @overload_get + def get_channel(self, channel_id: ChannelId, /) -> Optional[GuildChannel]: """Returns a channel with the given ID. .. note:: @@ -825,7 +858,8 @@ def get_channel(self, channel_id: int, /) -> Optional[GuildChannel]: """ return self._channels.get(channel_id) - def get_thread(self, thread_id: int, /) -> Optional[Thread]: + @overload_get + def get_thread(self, thread_id: ThreadId, /) -> Optional[Thread]: """Returns a thread with the given ID. .. versionadded:: 2.0 @@ -931,7 +965,8 @@ def members(self) -> List[Member]: """List[:class:`Member`]: A list of members that belong to this guild.""" return list(self._members.values()) - def get_member(self, user_id: int, /) -> Optional[Member]: + @overload_get + def get_member(self, user_id: MemberId, /) -> Optional[Member]: """Returns a member with the given ID. Parameters @@ -960,7 +995,8 @@ def roles(self) -> List[Role]: """ return sorted(self._roles.values()) - def get_role(self, role_id: int, /) -> Optional[Role]: + @overload_get + def get_role(self, role_id: RoleId, /) -> Optional[Role]: """Returns a role with the given ID. Parameters @@ -979,7 +1015,7 @@ def get_role(self, role_id: int, /) -> Optional[Role]: def default_role(self) -> Role: """:class:`Role`: Gets the @everyone role that all members have by default.""" # The @everyone role is *always* given - return self.get_role(self.id) # type: ignore + return self.get_role(self.id) @property def premium_subscriber_role(self) -> Optional[Role]: @@ -1198,7 +1234,7 @@ def _create_channel( name: str, channel_type: ChannelType, overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, **options: Any, ) -> Any: if overwrites is MISSING: @@ -1236,7 +1272,7 @@ async def create_text_channel( name: str, *, reason: Optional[str] = None, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, position: int = MISSING, topic: Optional[str] = MISSING, slowmode_delay: int = MISSING, @@ -1390,7 +1426,7 @@ async def create_voice_channel( self, name: str, *, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, position: int = MISSING, bitrate: int = MISSING, user_limit: int = MISSING, @@ -1514,7 +1550,7 @@ async def create_stage_channel( rtc_region: Optional[Union[str, VoiceRegion]] = MISSING, video_quality_mode: VideoQualityMode = MISSING, overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, nsfw: bool = MISSING, slowmode_delay: int = MISSING, reason: Optional[str] = None, @@ -1635,7 +1671,7 @@ async def create_forum_channel( name: str, *, topic: Optional[str] = None, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, position: int = MISSING, slowmode_delay: int = MISSING, default_thread_slowmode_delay: int = MISSING, @@ -1785,7 +1821,7 @@ async def create_media_channel( name: str, *, topic: Optional[str] = None, - category: Optional[Snowflake] = None, + category: Optional[Snowflake[CategoryId]] = None, position: int = MISSING, slowmode_delay: int = MISSING, default_thread_slowmode_delay: int = MISSING, @@ -2014,7 +2050,7 @@ async def edit( invites_disabled: bool = MISSING, raid_alerts_disabled: bool = MISSING, afk_channel: Optional[VoiceChannel] = MISSING, - owner: Snowflake = MISSING, + owner: Snowflake[MemberId] = MISSING, afk_timeout: int = MISSING, default_notifications: NotificationLevel = MISSING, verification_level: VerificationLevel = MISSING, @@ -2473,15 +2509,14 @@ async def create_scheduled_event( description: str = ..., image: AssetBytes = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... @overload async def create_scheduled_event( self, *, name: str, - channel: Snowflake, + channel: Snowflake[ChannelId], scheduled_start_time: datetime.datetime, entity_type: Literal[ GuildScheduledEventEntityType.voice, @@ -2492,8 +2527,7 @@ async def create_scheduled_event( description: str = ..., image: AssetBytes = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... @overload async def create_scheduled_event( @@ -2509,15 +2543,14 @@ async def create_scheduled_event( description: str = ..., image: AssetBytes = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... async def create_scheduled_event( self, *, name: str, scheduled_start_time: datetime.datetime, - channel: Optional[Snowflake] = MISSING, + channel: Optional[Snowflake[ChannelId]] = MISSING, entity_type: GuildScheduledEventEntityType = MISSING, scheduled_end_time: Optional[datetime.datetime] = MISSING, privacy_level: GuildScheduledEventPrivacyLevel = MISSING, @@ -2826,7 +2859,8 @@ def fetch_members( return MemberIterator(self, limit=limit, after=after) - async def fetch_member(self, member_id: int, /) -> Member: + @overload_fetch + async def fetch_member(self, member_id: MemberId, /) -> Member: """|coro| Retrieves a :class:`Member` with the given ID. @@ -2857,7 +2891,7 @@ async def fetch_member(self, member_id: int, /) -> Member: data = await self._state.http.get_member(self.id, member_id) return Member(data=data, state=self._state, guild=self) - async def fetch_ban(self, user: Snowflake) -> BanEntry: + async def fetch_ban(self, user: Snowflake[UserId]) -> BanEntry: """|coro| Retrieves the :class:`BanEntry` for a user. @@ -2887,7 +2921,15 @@ async def fetch_ban(self, user: Snowflake) -> BanEntry: data: BanPayload = await self._state.http.get_ban(user.id, self.id) return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"]) - async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: + @overload + async def fetch_channel(self, id: ChannelId, /) -> GuildChannel: ... + @overload + async def fetch_channel(self, id: ThreadId, /) -> Thread: ... + @overload + async def fetch_channel(self, id: ObjectId, /) -> Never: ... + @overload + async def fetch_channel(self, id: int, /) -> Union[GuildChannel, Thread]: ... + async def fetch_channel(self, id: int, /) -> Union[GuildChannel, Thread]: """|coro| Retrieves a :class:`.abc.GuildChannel` or :class:`.Thread` with the given ID. @@ -2916,7 +2958,7 @@ async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread] Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel from the ID. """ - data = await self._state.http.get_channel(channel_id) + data = await self._state.http.get_channel(id) factory, ch_type = _threaded_guild_channel_factory(data["type"]) if factory is None: @@ -2993,7 +3035,7 @@ async def prune_members( *, days: int, compute_prune_count: bool = True, - roles: List[Snowflake] = MISSING, + roles: List[Snowflake[RoleId]] = MISSING, reason: Optional[str] = None, ) -> Optional[int]: """|coro| @@ -3110,7 +3152,9 @@ async def webhooks(self) -> List[Webhook]: data = await self._state.http.guild_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int: + async def estimate_pruned_members( + self, *, days: int, roles: List[Snowflake[RoleId]] = MISSING + ) -> int: """|coro| Similar to :meth:`prune_members` except instead of actually @@ -3310,7 +3354,8 @@ async def fetch_stickers(self) -> List[GuildSticker]: data = await self._state.http.get_all_guild_stickers(self.id) return [GuildSticker(state=self._state, data=d) for d in data] - async def fetch_sticker(self, sticker_id: int, /) -> GuildSticker: + @overload_fetch + async def fetch_sticker(self, sticker_id: StickerId, /) -> GuildSticker: """|coro| Retrieves a custom :class:`Sticker` from the guild. @@ -3400,7 +3445,9 @@ async def create_sticker( data = await self._state.http.create_guild_sticker(self.id, payload, file, reason=reason) return self._state.store_sticker(self, data) - async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None: + async def delete_sticker( + self, sticker: Snowflake[StickerId], *, reason: Optional[str] = None + ) -> None: """|coro| Deletes the custom :class:`Sticker` from the guild. @@ -3448,7 +3495,8 @@ async def fetch_emojis(self) -> List[Emoji]: data = await self._state.http.get_all_custom_emojis(self.id) return [Emoji(guild=self, state=self._state, data=d) for d in data] - async def fetch_emoji(self, emoji_id: int, /) -> Emoji: + @overload_fetch + async def fetch_emoji(self, emoji_id: EmojiId, /) -> Emoji: """|coro| Retrieves a custom :class:`Emoji` from the guild. @@ -3557,7 +3605,9 @@ async def create_custom_emoji( ) return self._state.store_emoji(self, data) - async def delete_emoji(self, emoji: Snowflake, *, reason: Optional[str] = None) -> None: + async def delete_emoji( + self, emoji: Snowflake[EmojiId], *, reason: Optional[str] = None + ) -> None: """|coro| Deletes the custom :class:`Emoji` from the guild. @@ -3607,16 +3657,16 @@ async def fetch_roles(self) -> List[Role]: @overload async def get_or_fetch_member( - self, member_id: int, *, strict: Literal[False] = ... - ) -> Optional[Member]: - ... + self, member_id: MemberId, *, strict: Literal[False] = ... + ) -> Optional[Member]: ... @overload - async def get_or_fetch_member(self, member_id: int, *, strict: Literal[True]) -> Member: - ... + async def get_or_fetch_member( + self, member_id: MemberId, *, strict: Literal[True] + ) -> Member: ... async def get_or_fetch_member( - self, member_id: int, *, strict: bool = False + self, member_id: MemberId, *, strict: bool = False ) -> Optional[Member]: """|coro| @@ -3668,8 +3718,7 @@ async def create_role( icon: AssetBytes = ..., emoji: str = ..., mentionable: bool = ..., - ) -> Role: - ... + ) -> Role: ... @overload async def create_role( @@ -3683,8 +3732,7 @@ async def create_role( icon: AssetBytes = ..., emoji: str = ..., mentionable: bool = ..., - ) -> Role: - ... + ) -> Role: ... async def create_role( self, @@ -3791,7 +3839,7 @@ async def create_role( return role async def edit_role_positions( - self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None + self, positions: Dict[Snowflake[RoleId], int], *, reason: Optional[str] = None ) -> List[Role]: """|coro| @@ -3857,7 +3905,7 @@ async def edit_role_positions( return roles - async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None: + async def kick(self, user: Snowflake[MemberId], *, reason: Optional[str] = None) -> None: """|coro| Kicks a user from the guild. @@ -3886,26 +3934,24 @@ async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None: @overload async def ban( self, - user: Snowflake, + user: Snowflake[UserId], *, clean_history_duration: Union[int, datetime.timedelta] = 86400, reason: Optional[str] = None, - ) -> None: - ... + ) -> None: ... @overload async def ban( self, - user: Snowflake, + user: Snowflake[UserId], *, delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, reason: Optional[str] = None, - ) -> None: - ... + ) -> None: ... async def ban( self, - user: Snowflake, + user: Snowflake[UserId], *, clean_history_duration: Union[int, datetime.timedelta] = MISSING, delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = MISSING, @@ -3986,7 +4032,7 @@ async def ban( user.id, self.id, delete_message_seconds=delete_message_seconds, reason=reason ) - async def unban(self, user: Snowflake, *, reason: Optional[str] = None) -> None: + async def unban(self, user: Snowflake[UserId], *, reason: Optional[str] = None) -> None: """|coro| Unbans a user from the guild. @@ -4014,7 +4060,7 @@ async def unban(self, user: Snowflake, *, reason: Optional[str] = None) -> None: async def bulk_ban( self, - users: Iterable[Snowflake], + users: Iterable[Snowflake[UserId]], *, clean_history_duration: Union[int, datetime.timedelta] = 0, reason: Optional[str] = None, @@ -4151,7 +4197,7 @@ def audit_logs( limit: Optional[int] = 100, before: Optional[SnowflakeTime] = None, after: Optional[SnowflakeTime] = None, - user: Optional[Snowflake] = None, + user: Optional[Snowflake[UserId]] = None, action: Optional[AuditLogAction] = None, oldest_first: bool = False, ) -> AuditLogIterator: @@ -4280,7 +4326,7 @@ async def edit_widget( self, *, enabled: bool = MISSING, - channel: Optional[Snowflake] = MISSING, + channel: Optional[Snowflake[ChannelId]] = MISSING, reason: Optional[str] = None, ) -> WidgetSettings: """|coro| @@ -4643,7 +4689,11 @@ async def fetch_voice_regions(self) -> List[VoiceRegion]: return [VoiceRegion(data=region) for region in data] async def change_voice_state( - self, *, channel: Optional[Snowflake], self_mute: bool = False, self_deaf: bool = False + self, + *, + channel: Optional[Snowflake[ChannelId]], + self_mute: bool = False, + self_deaf: bool = False, ) -> None: """|coro| @@ -4675,8 +4725,9 @@ async def bulk_fetch_command_permissions(self) -> List[GuildApplicationCommandPe """ return await self._state.bulk_fetch_command_permissions(self.id) + @overload_fetch async def fetch_command_permissions( - self, command_id: int + self, command_id: Union[ApplicationCommandId, ApplicationId] ) -> GuildApplicationCommandPermissions: """|coro| @@ -4702,26 +4753,24 @@ async def fetch_command_permissions( @overload async def timeout( self, - user: Snowflake, + user: Snowflake[MemberId], *, duration: Optional[Union[float, datetime.timedelta]], reason: Optional[str] = None, - ) -> Member: - ... + ) -> Member: ... @overload async def timeout( self, - user: Snowflake, + user: Snowflake[MemberId], *, until: Optional[datetime.datetime], reason: Optional[str] = None, - ) -> Member: - ... + ) -> Member: ... async def timeout( self, - user: Snowflake, + user: Snowflake[MemberId], *, duration: Optional[Union[float, datetime.timedelta]] = MISSING, until: Optional[datetime.datetime] = MISSING, @@ -4848,8 +4897,8 @@ async def create_automod_rule( actions: Sequence[AutoModAction], trigger_metadata: Optional[AutoModTriggerMetadata] = None, enabled: bool = False, - exempt_roles: Optional[Sequence[Snowflake]] = None, - exempt_channels: Optional[Sequence[Snowflake]] = None, + exempt_roles: Optional[Sequence[Snowflake[RoleId]]] = None, + exempt_channels: Optional[Sequence[Snowflake[ChannelId]]] = None, reason: Optional[str] = None, ) -> AutoModRule: """|coro| diff --git a/disnake/guild_preview.py b/disnake/guild_preview.py index e6c7ec81b0..154bcfda79 100644 --- a/disnake/guild_preview.py +++ b/disnake/guild_preview.py @@ -7,6 +7,7 @@ from .asset import Asset from .emoji import Emoji from .sticker import GuildSticker +from .types.ids import GuildId if TYPE_CHECKING: from .state import ConnectionState @@ -61,7 +62,7 @@ class GuildPreview: def __init__(self, *, data: GuildPreviewPayload, state: ConnectionState) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id: GuildId = GuildId(int(data["id"])) self.name: str = data["name"] self.approximate_member_count: int = data["approximate_member_count"] self.approximate_presence_count: int = data["approximate_presence_count"] diff --git a/disnake/guild_scheduled_event.py b/disnake/guild_scheduled_event.py index 1b01be136c..3fec5457ec 100644 --- a/disnake/guild_scheduled_event.py +++ b/disnake/guild_scheduled_event.py @@ -14,6 +14,7 @@ try_enum, ) from .mixins import Hashable +from .types.ids import ChannelId, GuildId, MemberId from .utils import ( MISSING, _assetbytes_to_base64_data, @@ -157,9 +158,9 @@ def __init__(self, *, state: ConnectionState, data: GuildScheduledEventPayload) def _update(self, data: GuildScheduledEventPayload) -> None: self.id: int = int(data["id"]) - self.guild_id: int = int(data["guild_id"]) - self.channel_id: Optional[int] = _get_as_snowflake(data, "channel_id") - self.creator_id: Optional[int] = _get_as_snowflake(data, "creator_id") + self.guild_id: GuildId = GuildId(int(data["guild_id"])) + self.channel_id: Optional[ChannelId] = _get_as_snowflake(data, "channel_id", ChannelId) + self.creator_id: Optional[MemberId] = _get_as_snowflake(data, "creator_id", MemberId) self.name: str = data["name"] self.description: Optional[str] = data.get("description") self.scheduled_start_time: datetime = parse_time(data["scheduled_start_time"]) @@ -278,15 +279,14 @@ async def edit( name: str = ..., description: Optional[str] = ..., image: Optional[AssetBytes] = ..., - channel: Optional[Snowflake] = ..., + channel: Optional[Snowflake[ChannelId]] = ..., privacy_level: GuildScheduledEventPrivacyLevel = ..., scheduled_start_time: datetime = ..., scheduled_end_time: Optional[datetime] = ..., entity_metadata: Optional[GuildScheduledEventMetadata] = ..., status: GuildScheduledEventStatus = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... # new entity_type is `external`, no channel @overload @@ -304,8 +304,7 @@ async def edit( entity_metadata: GuildScheduledEventMetadata = ..., status: GuildScheduledEventStatus = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... # new entity_type is `voice` or `stage_instance`, no entity_metadata @overload @@ -319,14 +318,13 @@ async def edit( name: str = ..., description: Optional[str] = ..., image: Optional[AssetBytes] = ..., - channel: Snowflake = ..., + channel: Snowflake[ChannelId] = ..., privacy_level: GuildScheduledEventPrivacyLevel = ..., scheduled_start_time: datetime = ..., scheduled_end_time: Optional[datetime] = ..., status: GuildScheduledEventStatus = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... # channel=None, no entity_type @overload @@ -343,15 +341,14 @@ async def edit( entity_metadata: GuildScheduledEventMetadata = ..., status: GuildScheduledEventStatus = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... # valid channel, no entity_type @overload async def edit( self, *, - channel: Snowflake, + channel: Snowflake[ChannelId], name: str = ..., description: Optional[str] = ..., image: Optional[AssetBytes] = ..., @@ -360,8 +357,7 @@ async def edit( scheduled_end_time: Optional[datetime] = ..., status: GuildScheduledEventStatus = ..., reason: Optional[str] = ..., - ) -> GuildScheduledEvent: - ... + ) -> GuildScheduledEvent: ... async def edit( self, @@ -369,7 +365,7 @@ async def edit( name: str = MISSING, description: Optional[str] = MISSING, image: Optional[AssetBytes] = MISSING, - channel: Optional[Snowflake] = MISSING, + channel: Optional[Snowflake[ChannelId]] = MISSING, privacy_level: GuildScheduledEventPrivacyLevel = MISSING, scheduled_start_time: datetime = MISSING, scheduled_end_time: Optional[datetime] = MISSING, diff --git a/disnake/invite.py b/disnake/invite.py index 545159dfb8..18c17c9554 100644 --- a/disnake/invite.py +++ b/disnake/invite.py @@ -8,8 +8,9 @@ from .asset import Asset from .enums import ChannelType, InviteTarget, InviteType, NSFWLevel, VerificationLevel, try_enum from .guild_scheduled_event import GuildScheduledEvent -from .mixins import Hashable +from .mixins import EqualityComparable from .object import Object +from .types.ids import ChannelId, GuildId from .utils import _get_as_snowflake, parse_time, snowflake_time from .welcome_screen import WelcomeScreen @@ -91,7 +92,7 @@ class PartialInviteChannel: def __init__(self, *, state: ConnectionState, data: InviteChannelPayload) -> None: self._state = state - self.id: int = int(data["id"]) + self.id: ChannelId = ChannelId(int(data["id"])) self.name: Optional[str] = data.get("name") self.type: ChannelType = try_enum(ChannelType, data["type"]) if self.type is ChannelType.group: @@ -198,9 +199,9 @@ class PartialInviteGuild: "premium_subscription_count", ) - def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int) -> None: + def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: GuildId) -> None: self._state: ConnectionState = state - self.id: int = id + self.id: GuildId = id self.name: str = data["name"] self.features: List[GuildFeature] = data.get("features", []) self._icon: Optional[str] = data.get("icon") @@ -250,7 +251,7 @@ def splash(self) -> Optional[Asset]: return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes") -class Invite(Hashable): +class Invite(EqualityComparable): """Represents a Discord :class:`Guild` or :class:`abc.GuildChannel` invite. Depending on the way this object was created, some of the attributes can @@ -435,7 +436,9 @@ def __init__( ) inviter_data = data.get("inviter") - self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) # type: ignore + self.inviter: Optional[User] = ( + None if inviter_data is None else self._state.create_user(inviter_data) + ) # type: ignore self.channel: Optional[InviteChannelType] = self._resolve_channel( data.get("channel"), channel @@ -457,7 +460,9 @@ def __init__( self.guild_welcome_screen: Optional[WelcomeScreen] = None target_user_data = data.get("target_user") - self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data) # type: ignore + self.target_user: Optional[User] = ( + None if target_user_data is None else self._state.create_user(target_user_data) + ) # type: ignore self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) @@ -482,7 +487,7 @@ def from_incomplete(cls, *, state: ConnectionState, data: InvitePayload) -> Self # If we're here, then this is a group DM guild = None else: - guild_id = int(guild_data["id"]) + guild_id = GuildId(int(guild_data["id"])) guild = state._get_guild(guild_id) if guild is None: # If it's not cached, then it has to be a partial guild @@ -499,9 +504,9 @@ def from_incomplete(cls, *, state: ConnectionState, data: InvitePayload) -> Self @classmethod def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self: - guild_id: Optional[int] = _get_as_snowflake(data, "guild_id") - guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) - channel_id = int(data["channel_id"]) + guild_id = _get_as_snowflake(data, "guild_id", GuildId) + guild: Union[Guild, Object, None] = state._get_guild(guild_id) + channel_id = ChannelId(int(data["channel_id"])) if guild is not None: channel = guild.get_channel(channel_id) or Object(id=channel_id) else: @@ -527,7 +532,7 @@ def _resolve_guild( if data is None: return None - guild_id = int(data["id"]) + guild_id = GuildId(int(data["id"])) return PartialInviteGuild(self._state, data, guild_id) def _resolve_channel( diff --git a/disnake/member.py b/disnake/member.py index 149fc97ecc..7b07e903b6 100644 --- a/disnake/member.py +++ b/disnake/member.py @@ -31,6 +31,7 @@ from .flags import MemberFlags from .object import Object from .permissions import Permissions +from .types.ids import MemberId, RoleId, overload_get from .user import BaseUser, User, _UserTag from .utils import MISSING @@ -279,7 +280,7 @@ class Member(disnake.abc.Messageable, _UserTag): if TYPE_CHECKING: name: str - id: int + id: MemberId discriminator: str global_name: Optional[str] bot: bool @@ -302,8 +303,7 @@ def __init__( data: Union[MemberWithUserPayload, GuildMemberUpdateEvent], guild: Guild, state: ConnectionState, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -313,8 +313,7 @@ def __init__( guild: Guild, state: ConnectionState, user_data: UserPayload, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -334,7 +333,7 @@ def __init__( self.premium_since: Optional[datetime.datetime] = utils.parse_time( data.get("premium_since") ) - self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) + self._roles: utils.SnowflakeList[RoleId] = utils.SnowflakeList(map(int, data["roles"])) self._client_status: Dict[Optional[str], str] = {None: "offline"} self.activities: Tuple[ActivityTypes, ...] = () self.nick: Optional[str] = data.get("nick") @@ -356,10 +355,10 @@ def __repr__(self) -> str: f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _UserTag) and other.id == self.id - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -447,7 +446,8 @@ def _presence_update( ) -> Optional[Tuple[User, User]]: self.activities = tuple(create_activity(a, state=self._state) for a in data["activities"]) self._client_status = { - sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore + sys.intern(key): sys.intern(value) + for key, value in data.get("client_status", {}).items() # type: ignore } self._client_status[None] = sys.intern(data["status"]) @@ -787,8 +787,7 @@ async def ban( *, clean_history_duration: Union[int, datetime.timedelta] = 86400, reason: Optional[str] = None, - ) -> None: - ... + ) -> None: ... @overload async def ban( @@ -796,8 +795,7 @@ async def ban( *, delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, reason: Optional[str] = None, - ) -> None: - ... + ) -> None: ... async def ban( self, @@ -838,7 +836,7 @@ async def edit( mute: bool = MISSING, deafen: bool = MISSING, suppress: bool = MISSING, - roles: Sequence[disnake.abc.Snowflake] = MISSING, + roles: Sequence[Snowflake[RoleId]] = MISSING, voice_channel: Optional[VocalGuildChannel] = MISSING, timeout: Optional[Union[float, datetime.timedelta, datetime.datetime]] = MISSING, flags: MemberFlags = MISSING, @@ -1063,7 +1061,7 @@ async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = N await self.edit(voice_channel=channel, reason=reason) async def add_roles( - self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True + self, *roles: Snowflake[RoleId], reason: Optional[str] = None, atomic: bool = True ) -> None: """|coro| @@ -1103,7 +1101,7 @@ async def add_roles( await req(guild_id, user_id, role.id, reason=reason) async def remove_roles( - self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True + self, *roles: Snowflake[RoleId], reason: Optional[str] = None, atomic: bool = True ) -> None: """|coro| @@ -1148,7 +1146,8 @@ async def remove_roles( for role in roles: await req(guild_id, user_id, role.id, reason=reason) - def get_role(self, role_id: int, /) -> Optional[Role]: + @overload_get + def get_role(self, role_id: RoleId, /) -> Optional[Role]: """Returns a role with the given ID from roles which the member has. .. versionadded:: 2.0 @@ -1171,8 +1170,7 @@ async def timeout( *, duration: Optional[Union[float, datetime.timedelta]], reason: Optional[str] = None, - ) -> Member: - ... + ) -> Member: ... @overload async def timeout( @@ -1180,8 +1178,7 @@ async def timeout( *, until: Optional[datetime.datetime], reason: Optional[str] = None, - ) -> Member: - ... + ) -> Member: ... async def timeout( self, diff --git a/disnake/mentions.py b/disnake/mentions.py index 7faec1f250..030b2b1698 100644 --- a/disnake/mentions.py +++ b/disnake/mentions.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, List, Union from .enums import MessageType +from .types.ids import RoleId, UserId __all__ = ("AllowedMentions",) @@ -66,8 +67,8 @@ def __init__( self, *, everyone: bool = default, - users: Union[bool, List[Snowflake]] = default, - roles: Union[bool, List[Snowflake]] = default, + users: Union[bool, List[Snowflake[UserId]]] = default, + roles: Union[bool, List[Snowflake[RoleId]]] = default, replied_user: bool = default, ) -> None: self.everyone = everyone diff --git a/disnake/message.py b/disnake/message.py index f67da56a28..2efbe7e498 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -38,6 +38,19 @@ from .reaction import Reaction from .sticker import StickerItem from .threads import Thread +from .types.ids import ( + ApplicationId, + AttachmentId, + ChannelId, + GuildId, + InteractionId, + MemberId, + MessageId, + PrivateChannelId, + ThreadId, + UserId, + WebhookId, +) from .ui.action_row import components_to_dict from .user import User from .utils import MISSING, assert_never, escape_mentions @@ -314,7 +327,7 @@ class Attachment(Hashable): ) def __init__(self, *, data: AttachmentPayload, state: ConnectionState) -> None: - self.id: int = int(data["id"]) + self.id: AttachmentId = AttachmentId(int(data["id"])) self.size: int = data["size"] self.height: Optional[int] = data.get("height") self.width: Optional[int] = data.get("width") @@ -598,24 +611,24 @@ class MessageReference: def __init__( self, *, - message_id: int, - channel_id: int, - guild_id: Optional[int] = None, + message_id: MessageId, + channel_id: Union[ChannelId, PrivateChannelId, ThreadId], + guild_id: Optional[GuildId] = None, fail_if_not_exists: bool = True, ) -> None: self._state: Optional[ConnectionState] = None self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None - self.message_id: Optional[int] = message_id - self.channel_id: int = channel_id - self.guild_id: Optional[int] = guild_id + self.message_id: Optional[MessageId] = message_id + self.channel_id: Union[ChannelId, PrivateChannelId, ThreadId] = channel_id + self.guild_id: Optional[GuildId] = guild_id self.fail_if_not_exists: bool = fail_if_not_exists @classmethod def with_state(cls, state: ConnectionState, data: MessageReferencePayload) -> Self: self = cls.__new__(cls) - self.message_id = utils._get_as_snowflake(data, "message_id") - self.channel_id = int(data["channel_id"]) - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.message_id = utils._get_as_snowflake(data, "message_id", MessageId) + self.channel_id = ChannelId(int(data["channel_id"])) + self.guild_id = utils._get_as_snowflake(data, "guild_id", GuildId) self.fail_if_not_exists = data.get("fail_if_not_exists", True) self._state = state self.resolved = None @@ -721,14 +734,14 @@ def __init__( guild: Optional[Guild], data: InteractionMessageReferencePayload, ) -> None: - self.id: int = int(data["id"]) + self.id: InteractionId = InteractionId(int(data["id"])) self.type: InteractionType = try_enum(InteractionType, int(data["type"])) self.name: str = data["name"] user: Optional[Union[User, Member]] = None if guild: if isinstance(guild, Guild): # this can be a placeholder object in interactions - user = guild.get_member(int(data["user"]["id"])) + user = guild.get_member(MemberId(int(data["user"]["id"]))) # If not cached, try data from event. # This is only available via gateway (message_create/_edit), not HTTP @@ -799,7 +812,7 @@ def flatten_handlers(cls): @flatten_handlers -class Message(Hashable): +class Message(Hashable[MessageId]): """Represents a message from Discord. .. collapse:: operations @@ -993,9 +1006,13 @@ def __init__( data: MessagePayload, ) -> None: self._state: ConnectionState = state - self.id: int = int(data["id"]) - self.application_id: Optional[int] = utils._get_as_snowflake(data, "application_id") - self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id") + self.id = MessageId(int(data["id"])) + self.application_id: Optional[ApplicationId] = utils._get_as_snowflake( + data, "application_id", ApplicationId + ) + self.webhook_id: Optional[WebhookId] = utils._get_as_snowflake( + data, "webhook_id", WebhookId + ) self.reactions: List[Reaction] = [ Reaction(message=self, data=d) for d in data.get("reactions", []) ] @@ -1035,7 +1052,7 @@ def __init__( # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: - self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) + self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id", GuildId)) self.interaction: Optional[InteractionReference] = ( InteractionReference(state=state, guild=self.guild, data=interaction) @@ -1096,7 +1113,7 @@ def _try_patch(self, data, key, transform=None) -> None: setattr(self, key, transform(value)) def _add_reaction( - self, data: MessageReactionAddEvent, emoji: EmojiInputType, user_id: int + self, data: MessageReactionAddEvent, emoji: EmojiInputType, user_id: UserId ) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) is_me = user_id == self._state.self_id @@ -1117,7 +1134,7 @@ def _add_reaction( return reaction def _remove_reaction( - self, data: MessageReactionRemoveEvent, emoji: EmojiInputType, user_id: int + self, data: MessageReactionRemoveEvent, emoji: EmojiInputType, user_id: UserId ) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) @@ -1645,8 +1662,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -1662,8 +1678,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -1679,8 +1694,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -1696,8 +1710,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... async def edit( self, @@ -1977,7 +1990,7 @@ async def add_reaction(self, emoji: EmojiInputType) -> None: await self._state.http.add_reaction(self.channel.id, self.id, emoji) async def remove_reaction( - self, emoji: Union[EmojiInputType, Reaction], member: Snowflake + self, emoji: Union[EmojiInputType, Reaction], member: Snowflake[UserId] ) -> None: """|coro| @@ -2279,7 +2292,7 @@ class PartialMessage(Hashable): to_reference = Message.to_reference to_message_reference_dict = Message.to_message_reference_dict - def __init__(self, *, channel: MessageableChannel, id: int) -> None: + def __init__(self, *, channel: MessageableChannel, id: MessageId) -> None: if channel.type not in ( ChannelType.text, ChannelType.news, @@ -2297,7 +2310,7 @@ def __init__(self, *, channel: MessageableChannel, id: int) -> None: self.channel: MessageableChannel = channel self._state: ConnectionState = channel._state - self.id: int = id + self.id: MessageId = id def _update(self, data) -> None: # This is used for duck typing purposes. @@ -2357,8 +2370,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -2374,8 +2386,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -2391,8 +2402,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... @overload async def edit( @@ -2408,8 +2418,7 @@ async def edit( view: Optional[View] = ..., components: Optional[Components[MessageUIComponent]] = ..., delete_after: Optional[float] = ..., - ) -> Message: - ... + ) -> Message: ... async def edit( self, diff --git a/disnake/mixins.py b/disnake/mixins.py index 59e5861a87..58dc333e8e 100644 --- a/disnake/mixins.py +++ b/disnake/mixins.py @@ -1,26 +1,75 @@ # SPDX-License-Identifier: MIT +from typing import TYPE_CHECKING, Generic, Literal, overload + +from .types.ids import ( + ApplicationCommandId, + ApplicationId, + AttachmentId, + ChannelId, + EmojiId, + GuildId, + InteractionId, + MessageId, + PrivateChannelId, + RoleId, + StickerId, + ThreadId, + UserId, + WebhookId, +) + +if TYPE_CHECKING: + from typing_extensions import Self, TypeVar + + IdT = TypeVar( + "IdT", + ApplicationCommandId, + ApplicationId, + AttachmentId, + ChannelId, + EmojiId, + GuildId, + InteractionId, + MessageId, + PrivateChannelId, + RoleId, + StickerId, + ThreadId, + UserId, + WebhookId, + int, + infer_variance=True, + default=int, + ) + __all__ = ( "EqualityComparable", "Hashable", ) -class EqualityComparable: +class EqualityComparable(Generic[IdT]): __slots__ = () - id: int + id: IdT + @overload + def __eq__(self, other: "Self") -> bool: ... + @overload + def __eq__(self, other: object) -> Literal[False]: ... def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and other.id == self.id + return isinstance(other, self.__class__) and self.id == other.id + @overload + def __ne__(self, other: "Self") -> bool: ... + @overload + def __ne__(self, other: object) -> Literal[True]: ... def __ne__(self, other: object) -> bool: - if isinstance(other, self.__class__): - return other.id != self.id - return True + return not isinstance(other, self.__class__) or self.id != other.id -class Hashable(EqualityComparable): +class Hashable(EqualityComparable[IdT]): __slots__ = () def __hash__(self) -> int: diff --git a/disnake/object.py b/disnake/object.py index cd3048b6b1..6ccde69253 100644 --- a/disnake/object.py +++ b/disnake/object.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, SupportsInt, Union +from typing import TYPE_CHECKING, SupportsInt, Union, overload from . import utils from .mixins import Hashable @@ -10,12 +10,14 @@ if TYPE_CHECKING: import datetime + from .mixins import IdT + SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] __all__ = ("Object",) -class Object(Hashable): +class Object(Hashable["IdT"]): """Represents a generic Discord object. The purpose of this class is to allow you to create 'miniature' @@ -49,7 +51,11 @@ class Object(Hashable): The ID of the object. """ - def __init__(self, id: SupportsIntCast) -> None: + @overload + def __init__(self, id: "IdT") -> None: ... + @overload + def __init__(self, id: SupportsIntCast) -> None: ... + def __init__(self, id: Union["IdT", SupportsIntCast]) -> None: try: id = int(id) except ValueError: @@ -57,7 +63,7 @@ def __init__(self, id: SupportsIntCast) -> None: f"id parameter must be convertable to int not {id.__class__!r}" ) from None else: - self.id = id + self.id: IdT = id # type: ignore def __repr__(self) -> str: return f"" diff --git a/disnake/partial_emoji.py b/disnake/partial_emoji.py index 92656bb314..66d37cd05e 100644 --- a/disnake/partial_emoji.py +++ b/disnake/partial_emoji.py @@ -7,6 +7,7 @@ from . import utils from .asset import Asset, AssetMixin +from .types.ids import EmojiId __all__ = ("PartialEmoji",) @@ -24,7 +25,7 @@ class _EmojiTag: __slots__ = () - id: int + id: EmojiId def _to_partial(self) -> PartialEmoji: raise NotImplementedError @@ -75,12 +76,12 @@ class PartialEmoji(_EmojiTag, AssetMixin): ) if TYPE_CHECKING: - id: Optional[int] + id: Optional[EmojiId] def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None) -> None: self.animated = animated self.name = name - self.id = id + self.id = id # type: ignore self._state = None @classmethod @@ -89,7 +90,7 @@ def from_dict( ) -> Self: return cls( animated=data.get("animated", False), - id=utils._get_as_snowflake(data, "id"), + id=utils._get_as_snowflake(data, "id", EmojiId), name=data.get("name") or "", ) @@ -147,7 +148,7 @@ def with_state( *, name: str, animated: bool = False, - id: Optional[int] = None, + id: Optional[EmojiId] = None, ) -> Self: self = cls(name=name, animated=animated, id=id) self._state = state @@ -165,7 +166,7 @@ def __repr__(self) -> str: f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>" ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if self.is_unicode_emoji(): return isinstance(other, PartialEmoji) and self.name == other.name @@ -173,7 +174,7 @@ def __eq__(self, other: Any) -> bool: return self.id == other.id return False - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -254,8 +255,8 @@ async def read(self) -> bytes: # (e.g. default reaction, tag emoji) @staticmethod def _emoji_to_name_id( - emoji: Optional[Union[str, Emoji, PartialEmoji]] - ) -> Tuple[Optional[str], Optional[int]]: + emoji: Optional[Union[str, Emoji, PartialEmoji]], + ) -> Tuple[Optional[str], Optional[EmojiId]]: if emoji is None: return None, None diff --git a/disnake/reaction.py b/disnake/reaction.py index 0720759f6a..70811c4c45 100644 --- a/disnake/reaction.py +++ b/disnake/reaction.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from .iterators import ReactionIterator +from .types.ids import UserId __all__ = ("Reaction",) @@ -96,7 +97,7 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"" - async def remove(self, user: Snowflake) -> None: + async def remove(self, user: Snowflake[UserId]) -> None: """|coro| Removes the reaction by the provided :class:`User` from the message. diff --git a/disnake/role.py b/disnake/role.py index 89fa55804f..1ccb4e5dec 100644 --- a/disnake/role.py +++ b/disnake/role.py @@ -10,6 +10,7 @@ from .mixins import Hashable from .partial_emoji import PartialEmoji from .permissions import Permissions +from .types.ids import RoleId from .utils import MISSING, _assetbytes_to_base64_data, _get_as_snowflake, snowflake_time __all__ = ( @@ -224,7 +225,7 @@ class Role(Hashable): def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload) -> None: self.guild: Guild = guild self._state: ConnectionState = state - self.id: int = int(data["id"]) + self.id: RoleId = RoleId(int(data["id"])) self._update(data) def __str__(self) -> str: diff --git a/disnake/state.py b/disnake/state.py index c3263976c6..2243976bde 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -86,6 +86,17 @@ from .stage_instance import StageInstance from .sticker import GuildSticker from .threads import Thread, ThreadMember +from .types.ids import ( + ApplicationCommandId, + ApplicationId, + ChannelId, + EmojiId, + GuildId, + MessageId, + StickerId, + ThreadId, + UserId, +) from .ui.modal import Modal, ModalStore from .ui.view import View, ViewStore from .user import ClientUser, User @@ -119,13 +130,13 @@ class ChunkRequest: def __init__( self, - guild_id: int, + guild_id: GuildId, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True, ) -> None: - self.guild_id: int = guild_id + self.guild_id: GuildId = guild_id self.resolver: Callable[[int], Any] = resolver self.loop: asyncio.AbstractEventLoop = loop self.cache: bool = cache @@ -221,7 +232,9 @@ def __init__( self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None - self.application_id: Optional[int] = None if application_id is None else int(application_id) + self.application_id: Optional[ApplicationId] = ( + None if application_id is None else ApplicationId(int(application_id)) + ) self.heartbeat_timeout: float = heartbeat_timeout self.guild_ready_timeout: float = guild_ready_timeout if self.guild_ready_timeout < 0: @@ -297,14 +310,18 @@ def clear( # However, using weakrefs here unfortunately has a few drawbacks: # - the weakref slot + object in user objects likely results in a small increase in memory usage # - accesses on `_users` are slower, e.g. `__getitem__` takes ~1us with weakrefs and ~0.2us without - self._users: weakref.WeakValueDictionary[int, User] = weakref.WeakValueDictionary() - self._emojis: Dict[int, Emoji] = {} - self._stickers: Dict[int, GuildSticker] = {} - self._guilds: Dict[int, Guild] = {} + self._users: weakref.WeakValueDictionary[UserId, User] = weakref.WeakValueDictionary() + self._emojis: Dict[EmojiId, Emoji] = {} + self._stickers: Dict[StickerId, GuildSticker] = {} + self._guilds: Dict[GuildId, Guild] = {} if application_commands: - self._global_application_commands: Dict[int, APIApplicationCommand] = {} - self._guild_application_commands: Dict[int, Dict[int, APIApplicationCommand]] = {} + self._global_application_commands: Dict[ + ApplicationCommandId, APIApplicationCommand + ] = {} + self._guild_application_commands: Dict[ + GuildId, Dict[ApplicationCommandId, APIApplicationCommand] + ] = {} if views: self._view_store: ViewStore = ViewStore(self) @@ -312,19 +329,19 @@ def clear( if modals: self._modal_store: ModalStore = ModalStore(self) - self._voice_clients: Dict[int, VoiceProtocol] = {} + self._voice_clients: Dict[GuildId, VoiceProtocol] = {} # LRU of max size 128 self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict() # extra dict to look up private channels by user id - self._private_channels_by_user: Dict[int, DMChannel] = {} + self._private_channels_by_user: Dict[UserId, DMChannel] = {} if self.max_messages is not None: self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages) else: self._messages: Optional[Deque[Message]] = None def process_chunk_requests( - self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool + self, guild_id: GuildId, nonce: Optional[str], members: List[Member], complete: bool ) -> None: removed = [] for key, request in self._chunk_requests.items(): @@ -368,14 +385,14 @@ def intents(self) -> Intents: def voice_clients(self) -> List[VoiceProtocol]: return list(self._voice_clients.values()) - def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]: + def _get_voice_client(self, guild_id: Optional[GuildId]) -> Optional[VoiceProtocol]: # the keys of self._voice_clients are ints return self._voice_clients.get(guild_id) # type: ignore - def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None: + def _add_voice_client(self, guild_id: GuildId, voice: VoiceProtocol) -> None: self._voice_clients[guild_id] = voice - def _remove_voice_client(self, guild_id: int) -> None: + def _remove_voice_client(self, guild_id: GuildId) -> None: self._voice_clients.pop(guild_id, None) def _update_references(self, ws: DiscordWebSocket) -> None: @@ -383,7 +400,7 @@ def _update_references(self, ws: DiscordWebSocket) -> None: vc.main_ws = ws # type: ignore def store_user(self, data: UserPayload) -> User: - user_id = int(data["id"]) + user_id = UserId(int(data["id"])) try: return self._users[user_id] except KeyError: @@ -395,28 +412,28 @@ def store_user(self, data: UserPayload) -> User: def create_user(self, data: UserPayload) -> User: return User(state=self, data=data) - def get_user(self, id: Optional[int]) -> Optional[User]: + def get_user(self, id: Optional[UserId]) -> Optional[User]: # the keys of self._users are ints return self._users.get(id) # type: ignore def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: # the id will be present here - emoji_id = int(data["id"]) # type: ignore + emoji_id = EmojiId(int(data["id"])) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: - sticker_id = int(data["id"]) + sticker_id = StickerId(int(data["id"])) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: View, message_id: Optional[int] = None) -> None: + def store_view(self, view: View, message_id: Optional[MessageId] = None) -> None: self._view_store.add_view(view, message_id) - def store_modal(self, user_id: int, modal: Modal) -> None: + def store_modal(self, user_id: UserId, modal: Modal) -> None: self._modal_store.add_modal(user_id, modal) - def prevent_view_updates_for(self, message_id: int) -> Optional[View]: + def prevent_view_updates_for(self, message_id: MessageId) -> Optional[View]: return self._view_store.remove_message_tracking(message_id) @property @@ -427,7 +444,7 @@ def persistent_views(self) -> Sequence[View]: def guilds(self) -> List[Guild]: return list(self._guilds.values()) - def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: + def _get_guild(self, guild_id: Optional[GuildId]) -> Optional[Guild]: # the keys of self._guilds are ints if guild_id is None: return None @@ -448,9 +465,9 @@ def _remove_guild(self, guild: Guild) -> None: del guild def _get_global_application_command( - self, application_command_id: int + self, id: ApplicationCommandId ) -> Optional[APIApplicationCommand]: - return self._global_application_commands.get(application_command_id) + return self._global_application_commands.get(id) def _add_global_application_command( self, @@ -461,21 +478,21 @@ def _add_global_application_command( AssertionError("The provided application command does not have an ID") self._global_application_commands[application_command.id] = application_command - def _remove_global_application_command(self, application_command_id: int, /) -> None: - self._global_application_commands.pop(application_command_id, None) + def _remove_global_application_command(self, id: ApplicationCommandId, /) -> None: + self._global_application_commands.pop(id, None) def _clear_global_application_commands(self) -> None: self._global_application_commands.clear() def _get_guild_application_command( - self, guild_id: int, application_command_id: int + self, guild_id: GuildId, application_command_id: ApplicationCommandId ) -> Optional[APIApplicationCommand]: granula = self._guild_application_commands.get(guild_id) if granula is not None: return granula.get(application_command_id) def _add_guild_application_command( - self, guild_id: int, application_command: APIApplicationCommand + self, guild_id: GuildId, application_command: APIApplicationCommand ) -> None: if not application_command.id: AssertionError("The provided application command does not have an ID") @@ -487,14 +504,16 @@ def _add_guild_application_command( application_command.id: application_command } - def _remove_guild_application_command(self, guild_id: int, application_command_id: int) -> None: + def _remove_guild_application_command( + self, guild_id: GuildId, application_command_id: ApplicationCommandId + ) -> None: try: granula = self._guild_application_commands[guild_id] granula.pop(application_command_id, None) except KeyError: pass - def _clear_guild_application_commands(self, guild_id: int) -> None: + def _clear_guild_application_commands(self, guild_id: GuildId) -> None: self._guild_application_commands.pop(guild_id, None) def _get_global_command_named( @@ -505,7 +524,7 @@ def _get_global_command_named( return cmd def _get_guild_command_named( - self, guild_id: int, name: str, cmd_type: Optional[ApplicationCommandType] = None + self, guild_id: GuildId, name: str, cmd_type: Optional[ApplicationCommandType] = None ) -> Optional[APIApplicationCommand]: granula = self._guild_application_commands.get(guild_id, {}) for cmd in granula.values(): @@ -520,11 +539,11 @@ def emojis(self) -> List[Emoji]: def stickers(self) -> List[GuildSticker]: return list(self._stickers.values()) - def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]: + def get_emoji(self, emoji_id: Optional[EmojiId]) -> Optional[Emoji]: # the keys of self._emojis are ints return self._emojis.get(emoji_id) # type: ignore - def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: + def get_sticker(self, sticker_id: Optional[StickerId]) -> Optional[GuildSticker]: # the keys of self._stickers are ints return self._stickers.get(sticker_id) # type: ignore @@ -542,7 +561,7 @@ def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateCha self._private_channels.move_to_end(channel_id) # type: ignore return value - def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]: + def _get_private_channel_by_user(self, user_id: Optional[UserId]) -> Optional[DMChannel]: # the keys of self._private_channels are ints return self._private_channels_by_user.get(user_id) # type: ignore @@ -571,7 +590,7 @@ def _remove_private_channel(self, channel: PrivateChannel) -> None: if recipient is not None: self._private_channels_by_user.pop(recipient.id, None) - def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: + def _get_message(self, msg_id: Optional[MessageId]) -> Optional[Message]: return ( utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages @@ -598,19 +617,19 @@ def _get_guild_channel( self, data: Union[MessagePayload, gateway.TypingStartEvent], ) -> Tuple[Union[PartialChannel, Thread], Optional[Guild]]: - channel_id = int(data["channel_id"]) + channel_id = ChannelId(int(data["channel_id"])) try: - guild = self._get_guild(int(data["guild_id"])) + guild = self._get_guild(GuildId(int(data["guild_id"]))) except KeyError: # if we're here, this is a DM channel or an ephemeral message in a guild channel = self.get_channel(channel_id) if channel is None: if "author" in data: # MessagePayload - user_id = int(data["author"]["id"]) + user_id = UserId(int(data["author"]["id"])) else: # TypingStartEvent - user_id = int(data["user_id"]) + user_id = UserId(int(data["user_id"])) channel = DMChannel._from_message(self, channel_id, user_id) guild = None else: @@ -620,7 +639,7 @@ def _get_guild_channel( async def chunker( self, - guild_id: int, + guild_id: GuildId, query: str = "", limit: int = 0, presences: bool = False, @@ -637,7 +656,7 @@ async def query_members( guild: Guild, query: Optional[str], limit: int, - user_ids: Optional[List[int]], + user_ids: Optional[List[UserId]], cache: bool, presences: bool, ): @@ -1414,14 +1433,12 @@ def is_guild_evicted(self, guild) -> bool: @overload async def chunk_guild( self, guild: Guild, *, wait: Literal[False], cache: Optional[bool] = None - ) -> asyncio.Future[List[Member]]: - ... + ) -> asyncio.Future[List[Member]]: ... @overload async def chunk_guild( self, guild: Guild, *, wait: Literal[True] = True, cache: Optional[bool] = None - ) -> List[Member]: - ... + ) -> List[Member]: ... async def chunk_guild( self, guild: Guild, *, wait: bool = True, cache: Optional[bool] = None @@ -1975,7 +1992,7 @@ def parse_entitlement_delete(self, data: gateway.EntitlementDelete) -> None: self.dispatch("entitlement_delete", entitlement) def _get_reaction_user( - self, channel: MessageableChannel, user_id: int + self, channel: MessageableChannel, user_id: UserId ) -> Optional[Union[User, Member]]: if isinstance(channel, (TextChannel, VoiceChannel, Thread, StageChannel)): return channel.guild.get_member(user_id) @@ -2014,7 +2031,7 @@ def _get_emoji_from_fields( self, *, name: Optional[str], - id: Optional[int], + id: Optional[EmojiId], animated: Optional[bool] = False, ) -> Optional[Union[Emoji, PartialEmoji]]: """Convert partial emoji fields to proper emoji, if possible. @@ -2058,8 +2075,7 @@ def _get_partial_interaction_channel( guild: Optional[Union[Guild, Object]], *, return_messageable: Literal[False] = False, - ) -> AnyChannel: - ... + ) -> AnyChannel: ... @overload def _get_partial_interaction_channel( @@ -2068,8 +2084,7 @@ def _get_partial_interaction_channel( guild: Optional[Union[Guild, Object]], *, return_messageable: Literal[True], - ) -> MessageableChannel: - ... + ) -> MessageableChannel: ... # note: this resolves unknown types to `PartialMessageable` def _get_partial_interaction_channel( @@ -2080,7 +2095,7 @@ def _get_partial_interaction_channel( # this param is purely for type-checking, it has no effect on runtime behavior. return_messageable: bool = False, ) -> AnyChannel: - channel_id = int(data["id"]) + channel_id = ChannelId(int(data["id"])) channel_type = data["type"] factory, ch_type = _threaded_channel_factory(channel_type) @@ -2110,7 +2125,7 @@ def _get_partial_interaction_channel( ) ) - def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: + def get_channel(self, id: Union[ChannelId, ThreadId, None]) -> Optional[Union[Channel, Thread]]: if id is None: return None @@ -2143,10 +2158,13 @@ async def fetch_global_commands( *, with_localizations: bool = True, ) -> List[APIApplicationCommand]: - results = await self.http.get_global_commands(self.application_id, with_localizations=with_localizations) # type: ignore + results = await self.http.get_global_commands( + self.application_id, + with_localizations=with_localizations, # type: ignore + ) return [application_command_factory(data) for data in results] - async def fetch_global_command(self, command_id: int) -> APIApplicationCommand: + async def fetch_global_command(self, command_id: ApplicationCommandId) -> APIApplicationCommand: result = await self.http.get_global_command(self.application_id, command_id) # type: ignore return application_command_factory(result) @@ -2154,23 +2172,26 @@ async def create_global_command( self, application_command: ApplicationCommand ) -> APIApplicationCommand: result = await self.http.upsert_global_command( - self.application_id, application_command.to_dict() # type: ignore + self.application_id, # type: ignore + application_command.to_dict(), ) cmd = application_command_factory(result) self._add_global_application_command(cmd) return cmd async def edit_global_command( - self, command_id: int, new_command: ApplicationCommand + self, command_id: ApplicationCommandId, new_command: ApplicationCommand ) -> APIApplicationCommand: result = await self.http.edit_global_command( - self.application_id, command_id, new_command.to_dict() # type: ignore + self.application_id, # type: ignore + command_id, + new_command.to_dict(), ) cmd = application_command_factory(result) self._add_global_application_command(cmd) return cmd - async def delete_global_command(self, command_id: int) -> None: + async def delete_global_command(self, command_id: ApplicationCommandId) -> None: await self.http.delete_global_command(self.application_id, command_id) # type: ignore self._remove_global_application_command(command_id) @@ -2187,49 +2208,66 @@ async def bulk_overwrite_global_commands( async def fetch_guild_commands( self, - guild_id: int, + guild_id: GuildId, *, with_localizations: bool = True, ) -> List[APIApplicationCommand]: - results = await self.http.get_guild_commands(self.application_id, guild_id, with_localizations=with_localizations) # type: ignore + results = await self.http.get_guild_commands( + self.application_id, # type: ignore + guild_id, + with_localizations=with_localizations, + ) return [application_command_factory(data) for data in results] - async def fetch_guild_command(self, guild_id: int, command_id: int) -> APIApplicationCommand: + async def fetch_guild_command( + self, guild_id: GuildId, command_id: ApplicationCommandId + ) -> APIApplicationCommand: result = await self.http.get_guild_command(self.application_id, guild_id, command_id) # type: ignore return application_command_factory(result) async def create_guild_command( - self, guild_id: int, application_command: ApplicationCommand + self, guild_id: GuildId, application_command: ApplicationCommand ) -> APIApplicationCommand: result = await self.http.upsert_guild_command( - self.application_id, guild_id, application_command.to_dict() # type: ignore + self.application_id, # type: ignore + guild_id, + application_command.to_dict(), ) cmd = application_command_factory(result) self._add_guild_application_command(guild_id, cmd) return cmd async def edit_guild_command( - self, guild_id: int, command_id: int, new_command: ApplicationCommand + self, guild_id: GuildId, command_id: ApplicationCommandId, new_command: ApplicationCommand ) -> APIApplicationCommand: result = await self.http.edit_guild_command( - self.application_id, guild_id, command_id, new_command.to_dict() # type: ignore + self.application_id, # type: ignore + guild_id, + command_id, + new_command.to_dict(), ) cmd = application_command_factory(result) self._add_guild_application_command(guild_id, cmd) return cmd - async def delete_guild_command(self, guild_id: int, command_id: int) -> None: + async def delete_guild_command( + self, guild_id: GuildId, command_id: ApplicationCommandId + ) -> None: await self.http.delete_guild_command( - self.application_id, guild_id, command_id # type: ignore + self.application_id, # type: ignore + guild_id, + command_id, ) self._remove_guild_application_command(guild_id, command_id) async def bulk_overwrite_guild_commands( - self, guild_id: int, application_commands: List[ApplicationCommand] + self, guild_id: GuildId, application_commands: List[ApplicationCommand] ) -> List[APIApplicationCommand]: payload = [cmd.to_dict() for cmd in application_commands] results = await self.http.bulk_upsert_guild_commands( - self.application_id, guild_id, payload # type: ignore + self.application_id, # type: ignore + guild_id, + payload, ) commands = [application_command_factory(data) for data in results] self._guild_application_commands[guild_id] = {cmd.id: cmd for cmd in commands} @@ -2238,18 +2276,21 @@ async def bulk_overwrite_guild_commands( # Application command permissions async def bulk_fetch_command_permissions( - self, guild_id: int + self, guild_id: GuildId ) -> List[GuildApplicationCommandPermissions]: array = await self.http.get_guild_application_command_permissions( - self.application_id, guild_id # type: ignore + self.application_id, # type: ignore + guild_id, ) return [GuildApplicationCommandPermissions(state=self, data=obj) for obj in array] async def fetch_command_permissions( - self, guild_id: int, command_id: int + self, guild_id: GuildId, command_id: Union[ApplicationCommandId, ApplicationId] ) -> GuildApplicationCommandPermissions: data = await self.http.get_application_command_permissions( - self.application_id, guild_id, command_id # type: ignore + self.application_id, # type: ignore + guild_id, + command_id, ) return GuildApplicationCommandPermissions(state=self, data=data) @@ -2310,7 +2351,7 @@ def _update_member_references(self) -> None: async def chunker( self, - guild_id: int, + guild_id: GuildId, query: str = "", limit: int = 0, presences: bool = False, @@ -2427,7 +2468,7 @@ def parse_ready(self, data: gateway.ReadyEvent) -> None: pass else: if self.application_id is None: - self.application_id = utils._get_as_snowflake(application, "id") + self.application_id = utils._get_as_snowflake(application, "id", ApplicationId) self.application_flags = ApplicationFlags._from_value(application["flags"]) for guild_data in data["guilds"]: diff --git a/disnake/sticker.py b/disnake/sticker.py index 322d12485d..20a46ea8da 100644 --- a/disnake/sticker.py +++ b/disnake/sticker.py @@ -9,6 +9,7 @@ from .enums import StickerFormatType, StickerType, try_enum from .errors import InvalidData from .mixins import Hashable +from .types.ids import GuildId, StickerId from .utils import MISSING, _get_as_snowflake, cached_slot_property, find, get, snowflake_time __all__ = ( @@ -263,7 +264,7 @@ def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None: self._from_data(data) def _from_data(self, data: StickerPayload) -> None: - self.id: int = int(data["id"]) + self.id: StickerId = StickerId(int(data["id"])) self.name: str = data["name"] self.description: str = data.get("description") or "" self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"]) @@ -405,7 +406,7 @@ class GuildSticker(Sticker): def _from_data(self, data: GuildStickerPayload) -> None: super()._from_data(data) self.available: bool = data.get("available", True) - self.guild_id: int = int(data["guild_id"]) + self.guild_id: GuildId = GuildId(int(data["guild_id"])) user = data.get("user") self.user: Optional[User] = self._state.store_user(user) if user else None self.emoji: str = data["tags"] diff --git a/disnake/threads.py b/disnake/threads.py index b7a11f25c9..0c20e80ed3 100644 --- a/disnake/threads.py +++ b/disnake/threads.py @@ -13,6 +13,16 @@ from .mixins import Hashable from .partial_emoji import PartialEmoji, _EmojiTag from .permissions import Permissions +from .types.ids import ( + ChannelId, + EmojiId, + MemberId, + MessageId, + ThreadId, + UserId, + overload_fetch, + overload_get, +) from .utils import MISSING, _get_as_snowflake, _unique, parse_time, snowflake_time __all__ = ( @@ -166,7 +176,7 @@ class Thread(Messageable, Hashable): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None: self._state: ConnectionState = state self.guild: Guild = guild - self._members: Dict[int, ThreadMember] = {} + self._members: Dict[MemberId, ThreadMember] = {} self._from_data(data) async def _get_channel(self): @@ -183,12 +193,14 @@ def __str__(self) -> str: return self.name def _from_data(self, data: ThreadPayload) -> None: - self.id: int = int(data["id"]) - self.parent_id: int = int(data["parent_id"]) - self.owner_id: Optional[int] = _get_as_snowflake(data, "owner_id") + self.id: ThreadId = ThreadId(int(data["id"])) + self.parent_id: ChannelId = ChannelId(int(data["parent_id"])) + self.owner_id: Optional[MemberId] = _get_as_snowflake(data, "owner_id", MemberId) self.name: str = data["name"] self._type: ThreadType = try_enum(ChannelType, data["type"]) # type: ignore - self.last_message_id: Optional[int] = _get_as_snowflake(data, "last_message_id") + self.last_message_id: Optional[MessageId] = _get_as_snowflake( + data, "last_message_id", MessageId + ) self.slowmode_delay: int = data.get("rate_limit_per_user", 0) self.message_count: int = data.get("message_count") or 0 self.total_message_sent: int = data.get("total_message_sent") or 0 @@ -486,7 +498,7 @@ def permissions_for( return base - async def delete_messages(self, messages: Iterable[Snowflake]) -> None: + async def delete_messages(self, messages: Iterable[Snowflake[MessageId]]) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -798,7 +810,7 @@ async def leave(self) -> None: """ await self._state.http.leave_thread(self.id) - async def add_user(self, user: Snowflake) -> None: + async def add_user(self, user: Snowflake[UserId]) -> None: """|coro| Adds a user to this thread. @@ -822,7 +834,7 @@ async def add_user(self, user: Snowflake) -> None: """ await self._state.http.add_user_to_thread(self.id, user.id) - async def remove_user(self, user: Snowflake) -> None: + async def remove_user(self, user: Snowflake[UserId]) -> None: """|coro| Removes a user from this thread. @@ -843,7 +855,8 @@ async def remove_user(self, user: Snowflake) -> None: """ await self._state.http.remove_user_from_thread(self.id, user.id) - async def fetch_member(self, member_id: int, /) -> ThreadMember: + @overload_fetch + async def fetch_member(self, member_id: MemberId, /) -> ThreadMember: """|coro| Retrieves a single :class:`ThreadMember` from this thread. @@ -987,7 +1000,8 @@ async def remove_tags(self, *tags: Snowflake, reason: Optional[str] = None) -> N await self._state.http.edit_channel(self.id, applied_tags=new_tags, reason=reason) - def get_partial_message(self, message_id: int, /) -> PartialMessage: + @overload_get + def get_partial_message(self, message_id: MessageId, /) -> PartialMessage: """Creates a :class:`PartialMessage` from the message ID. This is useful if you want to work with a message and only have its ID without @@ -1012,7 +1026,7 @@ def get_partial_message(self, message_id: int, /) -> PartialMessage: def _add_member(self, member: ThreadMember) -> None: self._members[member.id] = member - def _pop_member(self, member_id: int) -> Optional[ThreadMember]: + def _pop_member(self, member_id: MemberId) -> Optional[ThreadMember]: return self._members.pop(member_id, None) @@ -1057,6 +1071,9 @@ class ThreadMember(Hashable): "_state", "parent", ) + if TYPE_CHECKING: + id: MemberId + thread_id: ThreadId def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None: self.parent = parent @@ -1070,14 +1087,14 @@ def __repr__(self) -> str: def _from_data(self, data: ThreadMemberPayload) -> None: try: - self.id = int(data["user_id"]) + self.id = MemberId(int(data["user_id"])) except KeyError as err: if (self_id := self._state.self_id) is None: raise AssertionError("self_id is None when updating our own ThreadMember.") from err - self.id = self_id + self.id = MemberId(self_id) try: - self.thread_id = int(data["id"]) + self.thread_id = ThreadId(int(data["id"])) except KeyError: self.thread_id = self.parent.id @@ -1210,7 +1227,7 @@ def to_dict(self) -> PartialForumTagPayload: def _from_data(cls, *, data: ForumTagPayload, state: ConnectionState) -> Self: emoji = state._get_emoji_from_fields( name=data.get("emoji_name"), - id=_get_as_snowflake(data, "emoji_id"), + id=_get_as_snowflake(data, "emoji_id", EmojiId), ) self = cls( diff --git a/disnake/types/ids.py b/disnake/types/ids.py new file mode 100644 index 0000000000..5b6dc19d70 --- /dev/null +++ b/disnake/types/ids.py @@ -0,0 +1,92 @@ +from typing import Any, Callable, Coroutine, List, NewType, Protocol, Union, overload + +from typing_extensions import Concatenate, Never, ParamSpec, TypeAlias, TypeVar + +__all__ = ( + "ApplicationCommandId", + "ApplicationId", + "AttachmentId", + "CategoryId", + "ChannelId", + "EmojiId", + "GuildId", + "InteractionId", + "MemberId", + "MessageId", + "PrivateChannelId", + "RoleId", + "StickerId", + "ThreadId", + "UserId", + "WebhookId", + "overload_fetch", + "overload_get", + "overload_get_seq", +) + +ApplicationCommandId = NewType("ApplicationCommandId", int) +ApplicationId = NewType("ApplicationId", int) +AttachmentId = NewType("AttachmentId", int) +ChannelId = NewType("ChannelId", int) +CategoryId: TypeAlias = ChannelId +EmojiId = NewType("EmojiId", int) +GuildId = NewType("GuildId", int) +InteractionId = NewType("InteractionId", int) +MessageId = NewType("MessageId", int) +PrivateChannelId = NewType("PrivateChannelId", int) +RoleId = NewType("RoleId", int) +StickerId = NewType("StickerId", int) +ThreadId = NewType("ThreadId", int) +UserId = NewType("UserId", int) +MemberId: TypeAlias = UserId +WebhookId = NewType("WebhookId", int) + +ChannelOrThreadId = Union[ChannelId, ThreadId] +ObjectId = Union[ + ApplicationCommandId, + ApplicationId, + AttachmentId, + ChannelId, + EmojiId, + GuildId, + InteractionId, + MessageId, + PrivateChannelId, + RoleId, + StickerId, + ThreadId, + UserId, + WebhookId, +] + +IdT = TypeVar("IdT", bound=ObjectId, infer_variance=True) +RetT = TypeVar("RetT", infer_variance=True) +RetInvalidT = TypeVar("RetInvalidT", infer_variance=True) +P = ParamSpec("P") + + +class AcceptsID(Protocol[IdT, P, RetT, RetInvalidT]): + @overload + def __call__(self, id: IdT, /, *args: P.args, **kwargs: P.kwargs) -> RetT: ... + @overload + def __call__(self, id: ObjectId, /, *args: P.args, **kwargs: P.kwargs) -> RetInvalidT: ... + @overload + def __call__(self, id: int, /, *args: P.args, **kwargs: P.kwargs) -> RetT: ... + + +def overload_fetch( + func: Callable[Concatenate[Any, IdT, P], Coroutine[Any, Any, RetT]], / +) -> AcceptsID[IdT, P, Coroutine[Any, Any, RetT], Coroutine[Any, Any, Never]]: + return func # type: ignore + + +def overload_get( + func: Callable[Concatenate[Any, IdT, P], RetT], / +) -> AcceptsID[IdT, P, RetT, None]: + return func # type: ignore + + +def overload_get_seq( + func: Callable[Concatenate[Any, IdT, P], List[RetT]], / +) -> AcceptsID[IdT, P, List[RetT], List[Never]]: + return func # type: ignore diff --git a/disnake/user.py b/disnake/user.py index b4a66505bb..531a18c5f5 100644 --- a/disnake/user.py +++ b/disnake/user.py @@ -10,6 +10,7 @@ from .colour import Colour from .enums import Locale, try_enum from .flags import PublicUserFlags +from .types.ids import UserId from .utils import MISSING, _assetbytes_to_base64_data, snowflake_time if TYPE_CHECKING: @@ -38,7 +39,7 @@ class _UserTag: __slots__ = () - id: int + id: UserId class BaseUser(_UserTag): @@ -59,7 +60,7 @@ class BaseUser(_UserTag): if TYPE_CHECKING: name: str - id: int + id: UserId discriminator: str global_name: Optional[str] bot: bool @@ -90,10 +91,10 @@ def __str__(self) -> str: # legacy behavior return f"{self.name}#{discriminator}" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, _UserTag) and other.id == self.id - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) def __hash__(self) -> int: @@ -101,7 +102,7 @@ def __hash__(self) -> int: def _update(self, data: Union[UserPayload, PartialUserPayload]) -> None: self.name = data["username"] - self.id = int(data["id"]) + self.id = UserId(int(data["id"])) self.discriminator = data["discriminator"] self.global_name = data.get("global_name") self._avatar = data["avatar"] diff --git a/disnake/utils.py b/disnake/utils.py index 9061cd0f61..6c94cc17bd 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -49,6 +49,7 @@ from urllib.parse import parse_qs, urlencode from .enums import Locale +from .types.ids import GuildId, IdT try: import orjson @@ -57,7 +58,6 @@ else: HAS_ORJSON = True - __all__ = ( "oauth_url", "snowflake_time", @@ -285,7 +285,7 @@ def oauth_url( client_id: Union[int, str], *, permissions: Permissions = MISSING, - guild: Snowflake = MISSING, + guild: Snowflake[GuildId] = MISSING, redirect_uri: str = MISSING, scopes: Iterable[str] = MISSING, disable_guild_select: bool = False, @@ -465,13 +465,16 @@ def _unique(iterable: Iterable[T]) -> List[T]: return list(dict.fromkeys(iterable)) -def _get_as_snowflake(data: Any, key: str) -> Optional[int]: +NumT = TypeVar("NumT", bound=int) + + +def _get_as_snowflake(data: Any, key: str, type_: Type[NumT] = int) -> Optional[NumT]: try: value = data[key] except KeyError: return None else: - return value and int(value) + return value and int(value) # type: ignore def _maybe_cast(value: V, converter: Callable[[V], T], default: T = None) -> Optional[T]: @@ -650,7 +653,7 @@ def valid_icon_size(size: int) -> bool: return not size & (size - 1) and 4096 >= size >= 16 -class SnowflakeList(array.array): +class SnowflakeList(array.array, Generic[IdT]): """Internal data storage class to efficiently store a list of snowflakes. This should have the following characteristics: @@ -666,21 +669,21 @@ class SnowflakeList(array.array): if TYPE_CHECKING: - def __init__(self, data: Iterable[int], *, is_sorted: bool = False) -> None: + def __init__(self, data: Iterable[IdT], *, is_sorted: bool = False) -> None: ... - def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): + def __new__(cls, data: Iterable[IdT], *, is_sorted: bool = False): return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore - def add(self, element: int) -> None: + def add(self, element: IdT) -> None: i = bisect_left(self, element) self.insert(i, element) - def get(self, element: int) -> Optional[int]: + def get(self, element: IdT) -> Optional[int]: i = bisect_left(self, element) return self[i] if i != len(self) and self[i] == element else None - def has(self, element: int) -> bool: + def has(self, element: IdT) -> bool: i = bisect_left(self, element) return i != len(self) and self[i] == element diff --git a/disnake/voice_client.py b/disnake/voice_client.py index e9469af670..5ef31ba469 100644 --- a/disnake/voice_client.py +++ b/disnake/voice_client.py @@ -28,6 +28,7 @@ from .errors import ClientException, ConnectionClosed from .gateway import DiscordVoiceWebSocket from .player import AudioPlayer, AudioSource +from .types.ids import ChannelId from .utils import MISSING if TYPE_CHECKING: @@ -486,7 +487,7 @@ async def disconnect(self, *, force: bool = False) -> None: if self.socket: self.socket.close() - async def move_to(self, channel: abc.Snowflake) -> None: + async def move_to(self, channel: abc.Snowflake[ChannelId]) -> None: """|coro| Moves you to a different voice channel. diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index 98650f4bf1..e649424df4 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -38,6 +38,15 @@ from ..message import Message from ..mixins import Hashable from ..object import Object +from ..types.ids import ( + ApplicationId, + ChannelId, + GuildId, + MessageId, + ThreadId, + WebhookId, + overload_fetch, +) from ..ui.action_row import MessageUIComponent, components_to_dict from ..user import BaseUser, User @@ -690,7 +699,7 @@ class PartialWebhookGuild(Hashable): def __init__(self, *, data, state) -> None: self._state = state - self.id = int(data["id"]) + self.id: GuildId = GuildId(int(data["id"])) self.name = data["name"] self._icon = data["icon"] @@ -723,7 +732,7 @@ def __init__( webhook: WebhookT, parent: Optional[Union[ConnectionState, _WebhookState]], *, - thread: Optional[Snowflake] = None, + thread: Optional[Snowflake[ThreadId]] = None, ) -> None: self._webhook: WebhookT = webhook @@ -733,7 +742,7 @@ def __init__( else: self._parent = parent - self._thread: Optional[Snowflake] = thread + self._thread: Optional[Snowflake[ThreadId]] = thread def _get_guild(self, guild_id): if self._parent is not None: @@ -959,10 +968,10 @@ def __init__( self._update(data) def _update(self, data: WebhookPayload) -> None: - self.id = int(data["id"]) + self.id: WebhookId = WebhookId(int(data["id"])) self.type = try_enum(WebhookType, int(data["type"])) - self.channel_id = utils._get_as_snowflake(data, "channel_id") - self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.channel_id = utils._get_as_snowflake(data, "channel_id", ChannelId) + self.guild_id = utils._get_as_snowflake(data, "guild_id", GuildId) self.name = data.get("name") self._avatar = data.get("avatar") self.token = data.get("token") @@ -985,7 +994,9 @@ def _update(self, data: WebhookPayload) -> None: self.source_guild: Optional[PartialWebhookGuild] = source_guild - self.application_id: Optional[int] = utils._get_as_snowflake(data, "application_id") + self.application_id: Optional[ApplicationId] = utils._get_as_snowflake( + data, "application_id", ApplicationId + ) def is_partial(self) -> bool: """Whether the webhook is a "partial" webhook. @@ -1366,7 +1377,7 @@ async def edit( reason: Optional[str] = None, name: Optional[str] = MISSING, avatar: Optional[AssetBytes] = MISSING, - channel: Optional[Snowflake] = None, + channel: Optional[Snowflake[ChannelId]] = None, prefer_auth: bool = True, ) -> Webhook: """|coro| @@ -1457,7 +1468,11 @@ async def edit( return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) def _create_message( - self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None + self, + data, + *, + thread: Optional[Snowflake[ThreadId]] = None, + thread_name: Optional[str] = None, ): channel_id = int(data["channel_id"]) @@ -1502,13 +1517,12 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - thread: Snowflake = ..., + thread: Snowflake[ThreadId] = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., wait: Literal[True], delete_after: float = ..., - ) -> WebhookMessage: - ... + ) -> WebhookMessage: ... @overload async def send( @@ -1529,13 +1543,12 @@ async def send( view: View = ..., components: Components[MessageUIComponent] = ..., poll: Poll = ..., - thread: Snowflake = ..., + thread: Snowflake[ThreadId] = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., wait: Literal[False] = ..., delete_after: float = ..., - ) -> None: - ... + ) -> None: ... async def send( self, @@ -1554,7 +1567,7 @@ async def send( allowed_mentions: AllowedMentions = MISSING, view: View = MISSING, components: Components[MessageUIComponent] = MISSING, - thread: Snowflake = MISSING, + thread: Snowflake[ThreadId] = MISSING, thread_name: str = MISSING, applied_tags: Sequence[Snowflake] = MISSING, wait: bool = False, @@ -1796,7 +1809,10 @@ async def send( return msg - async def fetch_message(self, id: int, *, thread: Optional[Snowflake] = None) -> WebhookMessage: + @overload_fetch + async def fetch_message( + self, id: MessageId, *, thread: Optional[Snowflake[ThreadId]] = None + ) -> WebhookMessage: """|coro| Retrieves a single :class:`WebhookMessage` owned by this webhook. @@ -1844,9 +1860,10 @@ async def fetch_message(self, id: int, *, thread: Optional[Snowflake] = None) -> ) return self._create_message(data, thread=thread) + @overload_fetch async def edit_message( self, - message_id: int, + message_id: MessageId, *, content: Optional[str] = MISSING, embed: Optional[Embed] = MISSING, @@ -1857,7 +1874,7 @@ async def edit_message( view: Optional[View] = MISSING, components: Optional[Components[MessageUIComponent]] = MISSING, allowed_mentions: Optional[AllowedMentions] = None, - thread: Optional[Snowflake] = None, + thread: Optional[Snowflake[ThreadId]] = None, ) -> WebhookMessage: """|coro| @@ -2009,7 +2026,7 @@ async def edit_message( return message async def delete_message( - self, message_id: int, /, *, thread: Optional[Snowflake] = None + self, message_id: int, /, *, thread: Optional[Snowflake[ThreadId]] = None ) -> None: """|coro| diff --git a/disnake/webhook/sync.py b/disnake/webhook/sync.py index bd9779db43..14e067bd6b 100644 --- a/disnake/webhook/sync.py +++ b/disnake/webhook/sync.py @@ -35,6 +35,7 @@ from ..http import Route from ..message import Message from ..object import Object +from ..types.ids import ChannelId, ThreadId from .async_ import BaseWebhook, _WebhookState, handle_message_parameters __all__ = ( @@ -799,7 +800,7 @@ def edit( reason: Optional[str] = None, name: Optional[str] = MISSING, avatar: Optional[bytes] = MISSING, - channel: Optional[Snowflake] = None, + channel: Optional[Snowflake[ChannelId]] = None, prefer_auth: bool = True, ) -> SyncWebhook: """Edits this Webhook. @@ -878,7 +879,11 @@ def edit( ) def _create_message( - self, data, *, thread: Optional[Snowflake] = None, thread_name: Optional[str] = None + self, + data, + *, + thread: Optional[Snowflake[ThreadId]] = None, + thread_name: Optional[str] = None, ): # see async webhook's _create_message for details channel_id = int(data["channel_id"]) @@ -907,12 +912,11 @@ def send( suppress_embeds: bool = ..., flags: MessageFlags = ..., allowed_mentions: AllowedMentions = ..., - thread: Snowflake = ..., + thread: Snowflake[ThreadId] = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., wait: Literal[True], - ) -> SyncWebhookMessage: - ... + ) -> SyncWebhookMessage: ... @overload def send( @@ -929,12 +933,11 @@ def send( suppress_embeds: bool = ..., flags: MessageFlags = ..., allowed_mentions: AllowedMentions = ..., - thread: Snowflake = ..., + thread: Snowflake[ThreadId] = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., wait: Literal[False] = ..., - ) -> None: - ... + ) -> None: ... def send( self, @@ -950,7 +953,7 @@ def send( suppress_embeds: bool = MISSING, flags: MessageFlags = MISSING, allowed_mentions: AllowedMentions = MISSING, - thread: Snowflake = MISSING, + thread: Snowflake[ThreadId] = MISSING, thread_name: str = MISSING, applied_tags: Sequence[Snowflake] = MISSING, wait: bool = False, @@ -1116,7 +1119,7 @@ def send( return self._create_message(data, thread=thread, thread_name=thread_name) def fetch_message( - self, id: int, /, *, thread: Optional[Snowflake] = None + self, id: int, /, *, thread: Optional[Snowflake[ThreadId]] = None ) -> SyncWebhookMessage: """Retrieves a single :class:`SyncWebhookMessage` owned by this webhook. @@ -1174,7 +1177,7 @@ def edit_message( files: List[File] = MISSING, attachments: Optional[List[Attachment]] = MISSING, allowed_mentions: Optional[AllowedMentions] = None, - thread: Optional[Snowflake] = None, + thread: Optional[Snowflake[ThreadId]] = None, ) -> SyncWebhookMessage: """Edits a message owned by this webhook. @@ -1284,7 +1287,9 @@ def edit_message( f.close() return self._create_message(data, thread=thread) - def delete_message(self, message_id: int, /, *, thread: Optional[Snowflake] = None) -> None: + def delete_message( + self, message_id: int, /, *, thread: Optional[Snowflake[ThreadId]] = None + ) -> None: """Deletes a message owned by this webhook. This is a lower level interface to :meth:`WebhookMessage.delete` in case diff --git a/disnake/widget.py b/disnake/widget.py index 4293985a36..37cf6cfbac 100644 --- a/disnake/widget.py +++ b/disnake/widget.py @@ -8,6 +8,7 @@ from .asset import Asset from .enums import Status, WidgetStyle, try_enum from .invite import Invite +from .types.ids import ChannelId from .user import BaseUser from .utils import MISSING, _get_as_snowflake, resolve_invite, snowflake_time @@ -223,7 +224,7 @@ async def edit( self, *, enabled: bool = MISSING, - channel: Optional[Snowflake] = MISSING, + channel: Optional[Snowflake[ChannelId]] = MISSING, reason: Optional[str] = None, ) -> WidgetSettings: """|coro| @@ -390,7 +391,7 @@ async def edit( self, *, enabled: bool = MISSING, - channel: Optional[Snowflake] = MISSING, + channel: Optional[Snowflake[ChannelId]] = MISSING, reason: Optional[str] = None, ) -> None: """|coro|