Skip to content

Commit

Permalink
Typed IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
Enegg committed Nov 15, 2024
1 parent 0961c9d commit e325f87
Show file tree
Hide file tree
Showing 31 changed files with 1,059 additions and 657 deletions.
82 changes: 42 additions & 40 deletions disnake/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@
from .permissions import PermissionOverwrite, Permissions
from .role import Role
from .sticker import GuildSticker, StandardSticker, StickerItem
from .types.ids import (
CategoryId,
ChannelId,
GuildId,
MessageId,
PrivateChannelId,
UserId,
overload_fetch,
)
from .ui.action_row import components_to_dict
from .utils import _overload_with_permissions
from .voice_client import VoiceClient, VoiceProtocol
Expand Down Expand Up @@ -74,6 +83,7 @@
from .iterators import HistoryIterator
from .member import Member
from .message import Message, MessageReference, PartialMessage
from .mixins import IdT
from .poll import Poll
from .state import ConnectionState
from .threads import AnyThreadArchiveDuration, ForumTag
Expand All @@ -100,7 +110,7 @@


@runtime_checkable
class Snowflake(Protocol):
class Snowflake(Protocol["IdT"]):
"""An ABC that details the common operations on a Discord model.
Almost all :ref:`Discord models <discord_model>` meet this
Expand All @@ -116,11 +126,13 @@ class Snowflake(Protocol):
"""

__slots__ = ()
id: int

@property
def id(self) -> IdT: ...


@runtime_checkable
class User(Snowflake, Protocol):
class User(Snowflake[UserId], Protocol):
"""An ABC that details the common operations on a Discord user.
The following classes implement this ABC:
Expand Down Expand Up @@ -180,7 +192,7 @@ def avatar(self) -> Optional[Asset]:


@runtime_checkable
class PrivateChannel(Snowflake, Protocol):
class PrivateChannel(Snowflake[PrivateChannelId], Protocol):
"""An ABC that details the common operations on a private Discord channel.
The following classes implement this ABC:
Expand Down Expand Up @@ -255,12 +267,12 @@ class GuildChannel(ABC):

__slots__ = ()

id: int
id: ChannelId
name: str
guild: Guild
type: ChannelType
position: int
category_id: Optional[int]
category_id: Optional[CategoryId]
_flags: int
_state: ConnectionState
_overwrites: List[_Overwrites]
Expand All @@ -269,8 +281,7 @@ class GuildChannel(ABC):

def __init__(
self, *, state: ConnectionState, guild: Guild, data: Mapping[str, Any]
) -> None:
...
) -> None: ...

def __str__(self) -> str:
return self.name
Expand All @@ -285,7 +296,7 @@ def _update(self, guild: Guild, data: Dict[str, Any]) -> None:
async def _move(
self,
position: int,
parent_id: Optional[int] = None,
parent_id: Optional[CategoryId] = None,
lock_permissions: bool = False,
*,
reason: Optional[str],
Expand Down Expand Up @@ -330,7 +341,7 @@ async def _edit(
position: int = MISSING,
nsfw: bool = MISSING,
sync_permissions: bool = MISSING,
category: Optional[Snowflake] = MISSING,
category: Optional[Snowflake[CategoryId]] = MISSING,
slowmode_delay: Optional[int] = MISSING,
default_thread_slowmode_delay: Optional[int] = MISSING,
default_auto_archive_duration: Optional[AnyThreadArchiveDuration] = MISSING,
Expand All @@ -347,7 +358,7 @@ async def _edit(
default_layout: ThreadLayout = MISSING,
reason: Optional[str] = None,
) -> Optional[ChannelPayload]:
parent_id: Optional[int]
parent_id: Optional[CategoryId]
if category is not MISSING:
# if category is given, it's either `None` (no parent) or a category channel
parent_id = category.id if category else None
Expand Down Expand Up @@ -793,7 +804,7 @@ def permissions_for(

# Apply channel specific role permission overwrites
for overwrite in remaining_overwrites:
if overwrite.is_role() and roles.has(overwrite.id):
if overwrite.is_role() and roles.has(overwrite.id): # type: ignore
denies |= overwrite.deny
allows |= overwrite.allow

Expand Down Expand Up @@ -843,8 +854,7 @@ async def set_permissions(
*,
overwrite: Optional[PermissionOverwrite] = ...,
reason: Optional[str] = ...,
) -> None:
...
) -> None: ...

@overload
@_overload_with_permissions
Expand Down Expand Up @@ -911,8 +921,7 @@ async def set_permissions(
view_channel: Optional[bool] = ...,
view_creator_monetization_analytics: Optional[bool] = ...,
view_guild_insights: Optional[bool] = ...,
) -> None:
...
) -> None: ...

async def set_permissions(
self,
Expand Down Expand Up @@ -1030,7 +1039,7 @@ async def _clone_impl(
base_attrs: Dict[str, Any],
*,
name: Optional[str] = None,
category: Optional[Snowflake] = MISSING,
category: Optional[Snowflake[CategoryId]] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
reason: Optional[str] = None,
) -> Self:
Expand Down Expand Up @@ -1115,47 +1124,43 @@ async def move(
*,
beginning: bool,
offset: int = ...,
category: Optional[Snowflake] = ...,
category: Optional[Snowflake[CategoryId]] = ...,
sync_permissions: bool = ...,
reason: Optional[str] = ...,
) -> None:
...
) -> None: ...

@overload
async def move(
self,
*,
end: bool,
offset: int = ...,
category: Optional[Snowflake] = ...,
category: Optional[Snowflake[CategoryId]] = ...,
sync_permissions: bool = ...,
reason: Optional[str] = ...,
) -> None:
...
) -> None: ...

@overload
async def move(
self,
*,
before: Snowflake,
offset: int = ...,
category: Optional[Snowflake] = ...,
category: Optional[Snowflake[CategoryId]] = ...,
sync_permissions: bool = ...,
reason: Optional[str] = ...,
) -> None:
...
) -> None: ...

@overload
async def move(
self,
*,
after: Snowflake,
offset: int = ...,
category: Optional[Snowflake] = ...,
category: Optional[Snowflake[CategoryId]] = ...,
sync_permissions: bool = ...,
reason: Optional[str] = ...,
) -> None:
...
) -> None: ...

async def move(self, **kwargs: Any) -> None:
"""|coro|
Expand Down Expand Up @@ -1442,8 +1447,7 @@ async def send(
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...
) -> Message: ...

@overload
async def send(
Expand All @@ -1464,8 +1468,7 @@ async def send(
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...
) -> Message: ...

@overload
async def send(
Expand All @@ -1486,8 +1489,7 @@ async def send(
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...
) -> Message: ...

@overload
async def send(
Expand All @@ -1508,8 +1510,7 @@ async def send(
view: View = ...,
components: Components[MessageUIComponent] = ...,
poll: Poll = ...,
) -> Message:
...
) -> Message: ...

async def send(
self,
Expand Down Expand Up @@ -1817,7 +1818,8 @@ def typing(self) -> Typing:
"""
return Typing(self)

async def fetch_message(self, id: int, /) -> Message:
@overload_fetch
async def fetch_message(self, id: MessageId, /) -> Message:
"""|coro|
Retrieves a single :class:`.Message` from the destination.
Expand Down Expand Up @@ -1961,9 +1963,9 @@ class Connectable(Protocol):
__slots__ = ()
_state: ConnectionState
guild: Guild
id: int
id: ChannelId

def _get_voice_client_key(self) -> Tuple[int, str]:
def _get_voice_client_key(self) -> Tuple[GuildId, str]:
raise NotImplementedError

def _get_voice_state_pair(self) -> Tuple[int, int]:
Expand Down
15 changes: 8 additions & 7 deletions disnake/app_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from .i18n import Localized
from .permissions import Permissions
from .types.ids import ApplicationCommandId, ApplicationId, GuildId
from .utils import MISSING, _get_as_snowflake, _maybe_cast

if TYPE_CHECKING:
Expand Down Expand Up @@ -608,9 +609,9 @@ class _APIApplicationCommandMixin:
__repr_info__ = ("id",)

def _update_common(self, data: ApplicationCommandPayload) -> None:
self.id: int = int(data["id"])
self.application_id: int = int(data["application_id"])
self.guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
self.id: ApplicationCommandId = ApplicationCommandId(int(data["id"]))
self.application_id: ApplicationId = ApplicationId(int(data["application_id"]))
self.guild_id: Optional[GuildId] = _get_as_snowflake(data, "guild_id", GuildId)
self.version: int = int(data["version"])
# deprecated, but kept until API stops returning this field
self._default_permission = data.get("default_permission") is not False
Expand Down Expand Up @@ -1016,13 +1017,13 @@ class ApplicationCommandPermissions:

__slots__ = ("id", "type", "permission", "_guild_id")

def __init__(self, *, data: ApplicationCommandPermissionsPayload, guild_id: int) -> None:
def __init__(self, *, data: ApplicationCommandPermissionsPayload, guild_id: GuildId) -> None:
self.id: int = int(data["id"])
self.type: ApplicationCommandPermissionType = try_enum(
ApplicationCommandPermissionType, data["type"]
)
self.permission: bool = data["permission"]
self._guild_id: int = guild_id
self._guild_id: GuildId = guild_id

def __repr__(self) -> str:
return f"<ApplicationCommandPermissions id={self.id!r} type={self.type!r} permission={self.permission!r}>"
Expand Down Expand Up @@ -1079,8 +1080,8 @@ def __init__(
) -> None:
self._state: ConnectionState = state
self.id: int = int(data["id"])
self.application_id: int = int(data["application_id"])
self.guild_id: int = int(data["guild_id"])
self.application_id: ApplicationId = ApplicationId(int(data["application_id"]))
self.guild_id: GuildId = GuildId(int(data["guild_id"]))

self.permissions: List[ApplicationCommandPermissions] = [
ApplicationCommandPermissions(data=elem, guild_id=self.guild_id)
Expand Down
13 changes: 7 additions & 6 deletions disnake/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from . import utils
from .errors import DiscordException
from .file import File
from .types.ids import GuildId, MemberId, RoleId, UserId, WebhookId

__all__ = ("Asset",)

Expand Down Expand Up @@ -214,7 +215,7 @@ def _from_default_avatar(cls, state: AnyState, index: int) -> Self:
)

@classmethod
def _from_avatar(cls, state: AnyState, user_id: int, avatar: str) -> Self:
def _from_avatar(cls, state: AnyState, user_id: Union[UserId, WebhookId], avatar: str) -> Self:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -226,7 +227,7 @@ def _from_avatar(cls, state: AnyState, user_id: int, avatar: str) -> Self:

@classmethod
def _from_guild_avatar(
cls, state: AnyState, guild_id: int, member_id: int, avatar: str
cls, state: AnyState, guild_id: GuildId, member_id: MemberId, avatar: str
) -> Self:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
Expand Down Expand Up @@ -256,7 +257,7 @@ def _from_cover_image(cls, state: AnyState, object_id: int, cover_image_hash: st
)

@classmethod
def _from_guild_image(cls, state: AnyState, guild_id: int, image: str, path: str) -> Self:
def _from_guild_image(cls, state: AnyState, guild_id: GuildId, image: str, path: str) -> Self:
return cls(
state,
url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
Expand All @@ -265,7 +266,7 @@ def _from_guild_image(cls, state: AnyState, guild_id: int, image: str, path: str
)

@classmethod
def _from_guild_icon(cls, state: AnyState, guild_id: int, icon_hash: str) -> Self:
def _from_guild_icon(cls, state: AnyState, guild_id: GuildId, icon_hash: str) -> Self:
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand Down Expand Up @@ -296,7 +297,7 @@ def _from_banner(cls, state: AnyState, id: int, banner_hash: str) -> Self:
)

@classmethod
def _from_role_icon(cls, state: AnyState, role_id: int, icon_hash: str) -> Self:
def _from_role_icon(cls, state: AnyState, role_id: RoleId, icon_hash: str) -> Self:
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand Down Expand Up @@ -337,7 +338,7 @@ def __repr__(self) -> str:
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Asset) and self._url == other._url

def __hash__(self) -> int:
Expand Down
Loading

0 comments on commit e325f87

Please sign in to comment.