diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84222a420..a05cbfdfd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: - id: check-merge-conflict name: Merge Conflicts - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 23.1.0 hooks: - id: black name: Black Formatting @@ -45,7 +45,7 @@ repos: types: [file, python] args: [--max-line-length=100, --ignore=E203 E301 E302 E501 E402 E704 W503 W504] - repo: https://github.com/pycqa/isort - rev: 5.11.4 + rev: 5.12.0 hooks: - id: isort name: isort Formatting diff --git a/docs/events.rst b/docs/events.rst index 9699056d0..4af7801f4 100644 --- a/docs/events.rst +++ b/docs/events.rst @@ -52,6 +52,7 @@ There are several different internal events: - ``raw_socket_create`` - ``on_start`` + - ``on_disconnect`` - ``on_interaction`` - ``on_command`` - ``on_command_error`` @@ -82,6 +83,13 @@ This function takes no arguments. .. attention:: Unlike ``on_ready``, this event will never be dispatched more than once. +Event: ``on_disconnect`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +This event fires whenever the connection is invalidated and will often precede an ``on_ready`` event + +This function takes no arguments. + + Event: ``on_interaction`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^ This event fires on any interaction (commands, components, autocomplete and modals). diff --git a/interactions/api/cache.py b/interactions/api/cache.py index 03ca04265..3d0e92cf8 100644 --- a/interactions/api/cache.py +++ b/interactions/api/cache.py @@ -66,7 +66,6 @@ def merge(self, item: _T, id: Optional["Key"] = None) -> None: continue # we can only assume that discord did not provide it, falsely deleting is worse than not deleting if getattr(old_item, attrib) != getattr(item, attrib): - if isinstance(getattr(item, attrib), list) and not isinstance( getattr(old_item, attrib), list ): # could be None diff --git a/interactions/api/dispatch.py b/interactions/api/dispatch.py index bb06c37db..840311e46 100644 --- a/interactions/api/dispatch.py +++ b/interactions/api/dispatch.py @@ -37,7 +37,6 @@ def dispatch(self, name: str, /, *args, **kwargs) -> None: if converters := getattr(event, "_converters", None): _kwargs = kwargs.copy() for key, value in _kwargs.items(): - if key in converters.keys(): del kwargs[key] kwargs[converters[key]] = value diff --git a/interactions/api/error.py b/interactions/api/error.py index 77a7e8869..58957082b 100644 --- a/interactions/api/error.py +++ b/interactions/api/error.py @@ -317,7 +317,6 @@ def __init__(self, code: int = 0, message: str = None, severity: int = 0, **kwar self.log(self.message) if _fmt_error: - _flag: bool = ( self.message.lower() in self.lookup(self.code).lower() ) # creativity is hard diff --git a/interactions/api/gateway/client.py b/interactions/api/gateway/client.py index 73bce8ac4..8667284a2 100644 --- a/interactions/api/gateway/client.py +++ b/interactions/api/gateway/client.py @@ -17,7 +17,6 @@ wait_for, ) from contextlib import suppress -from enum import IntEnum from sys import platform, version_info from time import perf_counter from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -26,7 +25,7 @@ from aiohttp import ClientWebSocketResponse, WSMessage, WSMsgType from ...base import __version__, get_logger -from ...client.enums import ComponentType, InteractionType, OptionType +from ...client.enums import ComponentType, IntEnum, InteractionType, OptionType from ...client.models import Option from ...utils.missing import MISSING from ..dispatch import Listener @@ -242,7 +241,6 @@ async def _manage_heartbeat(self) -> None: """Manages the heartbeat loop.""" log.debug(f"Sending heartbeat every {self.__heartbeater.delay / 1000} seconds...") while not self.__heartbeat_event.is_set(): - log.debug("Sending heartbeat...") if not self.__heartbeater.event.is_set(): log.debug("HEARTBEAT_ACK missing, reconnecting...") @@ -404,7 +402,6 @@ def _dispatch_interaction_event(self, data: dict) -> None: _option = self.__sub_command_context(option, _context) __kwargs.update(_option) - self._dispatch.dispatch("on_command", _context) elif data["type"] == InteractionType.MESSAGE_COMPONENT: _name = f"component_{_context.data.custom_id}" @@ -931,6 +928,12 @@ async def _reconnect(self, to_resume: bool, code: Optional[int] = 1012) -> None: if self.__heartbeat_event.is_set(): self.__heartbeat_event.clear() # Because we're hardresetting the process + self._dispatch.dispatch( + "on_disconnect" + ) # will be followed by the on_ready event after reconnection + # reconnection happens whenever it disconnects either with or without a resume prompt + # as this is called whenever the WS client closes + if not to_resume: url = self.ws_url if self.ws_url else await self._http.get_gateway() else: @@ -963,7 +966,6 @@ async def __receive_packet(self, ignore_lock: bool = False) -> Optional[Dict[str buffer = bytearray() while True: - if not ignore_lock: # meaning if we're reconnecting or something because of tasks await self.__closed.wait() @@ -1003,7 +1005,6 @@ async def __receive_packet(self, ignore_lock: bool = False) -> Optional[Dict[str await self._reconnect(True) elif packet.type == WSMsgType.CLOSING: - if ignore_lock: raise LibraryException( message="Discord unexpectedly closing on receiving by force.", severity=50 @@ -1061,7 +1062,6 @@ async def _send_packet(self, data: Dict[str, Any]) -> None: await self._client.send_str(packet) else: async with self.reconnect_lock: # needs to lock while it reconnects. - if data["op"] != OpCodeType.HEARTBEAT.value: # This is because the ratelimiter limits already accounts for this. await self._ratelimiter.block() diff --git a/interactions/api/http/interaction.py b/interactions/api/http/interaction.py index 1c0945fcd..19f23f404 100644 --- a/interactions/api/http/interaction.py +++ b/interactions/api/http/interaction.py @@ -236,7 +236,6 @@ async def create_interaction_response( file_data = None if files: - file_data = MultipartWriter("form-data") part = file_data.append_json(data) part.set_content_disposition("form-data", name="payload_json") @@ -295,7 +294,6 @@ async def edit_interaction_response( # ^ again, I don't know if python will let me file_data = None if files: - file_data = MultipartWriter("form-data") part = file_data.append_json(data) part.set_content_disposition("form-data", name="payload_json") diff --git a/interactions/api/http/message.py b/interactions/api/http/message.py index a86e3ce11..c55c19794 100644 --- a/interactions/api/http/message.py +++ b/interactions/api/http/message.py @@ -84,7 +84,6 @@ async def create_message( data = None if files is not MISSING and len(files) > 0: - data = MultipartWriter("form-data") part = data.append_json(payload) part.set_content_disposition("form-data", name="payload_json") @@ -172,7 +171,6 @@ async def edit_message( """ data = None if files is not MISSING and len(files) > 0: - data = MultipartWriter("form-data") part = data.append_json(payload) part.set_content_disposition("form-data", name="payload_json") diff --git a/interactions/api/http/request.py b/interactions/api/http/request.py index fa6531589..2d13361d4 100644 --- a/interactions/api/http/request.py +++ b/interactions/api/http/request.py @@ -209,7 +209,7 @@ async def request(self, route: Route, **kwargs) -> Optional[Any]: await asyncio.sleep(_limiter.reset_after) continue if remaining is not None and int(remaining) == 0: - log.warning( + log.debug( f"The HTTP client has exhausted a per-route ratelimit. Locking route for {reset_after} seconds." ) self._loop.call_later(reset_after, _limiter.release_lock) diff --git a/interactions/api/http/thread.py b/interactions/api/http/thread.py index 6ae7e8cc9..ed1332c94 100644 --- a/interactions/api/http/thread.py +++ b/interactions/api/http/thread.py @@ -231,7 +231,6 @@ async def create_thread_in_forum( data = None if files is not MISSING and files and len(files) > 0: # edge case `None` - data = MultipartWriter("form-data") part = data.append_json(payload) part.set_content_disposition("form-data", name="payload_json") diff --git a/interactions/api/http/user.py b/interactions/api/http/user.py index d9008044b..54dde6ed3 100644 --- a/interactions/api/http/user.py +++ b/interactions/api/http/user.py @@ -10,7 +10,6 @@ class UserRequest: - _req: _Request cache: Cache diff --git a/interactions/api/http/webhook.py b/interactions/api/http/webhook.py index 467c83f39..fdc607710 100644 --- a/interactions/api/http/webhook.py +++ b/interactions/api/http/webhook.py @@ -12,7 +12,6 @@ class WebhookRequest: - _req: _Request cache: Cache @@ -121,7 +120,6 @@ async def execute_webhook( data = None if files is not MISSING and len(files) > 0: - data = MultipartWriter("form-data") part = data.append_json(payload) part.set_content_disposition("form-data", name="payload_json") diff --git a/interactions/api/models/audit_log.py b/interactions/api/models/audit_log.py index 7bf6f9134..1b75d97fb 100644 --- a/interactions/api/models/audit_log.py +++ b/interactions/api/models/audit_log.py @@ -1,8 +1,8 @@ # versionadded declared in docs gen file -from enum import IntEnum from typing import TYPE_CHECKING, List, Optional, TypeVar +from ...client.enums import IntEnum from ...utils.attrs_utils import DictSerializerMixin, convert_list, define, field from .channel import Channel from .misc import Snowflake @@ -235,7 +235,7 @@ class AuditLogEntry(DictSerializerMixin): :ivar Snowflake id: ID of the entry :ivar AuditLogEvents action_type: Type of action that occurred :ivar OptionalAuditEntryInfo options: Additional info for certain event types - :ivar str reason: Reason for the change (1-512 characters) + :ivar Optional[str] reason: Reason for the change (1-512 characters) """ target_id: Optional[str] = field(default=None) diff --git a/interactions/api/models/channel.py b/interactions/api/models/channel.py index 6080f798d..4183b3719 100644 --- a/interactions/api/models/channel.py +++ b/interactions/api/models/channel.py @@ -1,6 +1,5 @@ from asyncio import Task, create_task, get_running_loop, sleep from datetime import datetime, timedelta, timezone -from enum import IntEnum from inspect import isawaitable from math import inf from typing import ( @@ -17,11 +16,13 @@ ) from warnings import warn +from ...client.enums import IntEnum from ...utils.abc.base_context_managers import BaseAsyncContextManager from ...utils.abc.base_iterators import DiscordPaginationIterator from ...utils.attrs_utils import ( ClientSerializerMixin, DictSerializerMixin, + convert_int, convert_list, define, field, @@ -232,7 +233,6 @@ async def __anext__(self) -> "Message": obj = self.objects.pop(0) if self.check: - res = self.check(obj) _res = await res if isawaitable(res) else res while not _res: @@ -271,7 +271,6 @@ def __init__( obj: Union[int, str, "Snowflake", "Channel"], _client: "HTTPClient", ): - try: self.loop = get_running_loop() except RuntimeError as e: @@ -433,7 +432,7 @@ class Channel(ClientSerializerMixin, IDMixin): :ivar Optional[ThreadMetadata] thread_metadata: The thread metadata of the channel. :ivar Optional[ThreadMember] member: The member of the thread in the channel. :ivar Optional[int] default_auto_archive_duration: The set auto-archive time for all threads to naturally follow in the channel. - :ivar Optional[str] permissions: The permissions of the channel. + :ivar Optional[Permissions] permissions: The permissions of the channel. :ivar Optional[int] flags: The flags of the channel. :ivar Optional[int] total_message_sent: Number of messages ever sent in a thread. :ivar Optional[int] default_thread_slowmode_delay: The default slowmode delay in seconds for threads, if this channel is a forum. @@ -484,7 +483,9 @@ class Channel(ClientSerializerMixin, IDMixin): converter=ThreadMember, default=None, add_client=True, repr=False ) default_auto_archive_duration: Optional[int] = field(default=None) - permissions: Optional[str] = field(default=None, repr=False) + permissions: Optional[Permissions] = field( + converter=convert_int(Permissions), default=None, repr=False + ) flags: Optional[int] = field(default=None, repr=False) total_message_sent: Optional[int] = field(default=None, repr=False) default_thread_slowmode_delay: Optional[int] = field(default=None, repr=False) @@ -1322,7 +1323,6 @@ async def bulk_delete(): _allowed_time = datetime.now(tz=timezone.utc) - timedelta(days=14) _stop = False while amount > 100: - messages = [ Message(**res) for res in await self._client.get_channel_messages( @@ -1938,7 +1938,6 @@ async def create_forum_post( _content["attachments"].append(attach._json) else: - _data = await attach.download() __files.append(File(attach.filename, _data)) @@ -2070,7 +2069,7 @@ async def add_permission_overwrite( _id = int(id) _type = type - if not _type: + if _type is MISSING: raise LibraryException(12, "Please set the type of the overwrite!") overwrites.append(Overwrite(id=_id, type=_type, allow=allow, deny=deny)) diff --git a/interactions/api/models/flags.py b/interactions/api/models/flags.py index 438c8d77f..51afc3849 100644 --- a/interactions/api/models/flags.py +++ b/interactions/api/models/flags.py @@ -1,6 +1,16 @@ -from enum import Enum, IntFlag +from enum import IntFlag -__all__ = ("Intents", "AppFlags", "StatusType", "UserFlags", "Permissions", "MessageFlags") +from ...client.enums import StrEnum + +__all__ = ( + "Intents", + "AppFlags", + "StatusType", + "UserFlags", + "Permissions", + "MessageFlags", + "MemberFlags", +) class Intents(IntFlag): @@ -176,7 +186,7 @@ class AppFlags(IntFlag): APPLICATION_COMMAND_BADGE = 1 << 23 -class StatusType(str, Enum): +class StatusType(StrEnum): """ An enumerable object representing Discord status icons that a user may have. """ @@ -214,3 +224,21 @@ class MessageFlags(IntFlag): EPHEMERAL = 1 << 6 LOADING = 1 << 7 FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8 + + +class MemberFlags(IntFlag): + """ + .. versionadded:: 4.4.0 + + An integer flag bitshift object representing member flags on the guild. + + :ivar int DID_REJOIN: Member has left and rejoined the guild + :ivar int COMPLETED_ONBOARDING: Member has completed onboarding + :ivar int BYPASSES_VERIFICATION: Member bypasses guild verification requirements + :ivar int STARTED_ONBOARDING: Member has started onboarding + """ + + DID_REJOIN = 1 << 0 + COMPLETED_ONBOARDING = 1 << 1 + BYPASSES_VERIFICATION = 1 << 2 + STARTED_ONBOARDING = 1 << 3 diff --git a/interactions/api/models/guild.py b/interactions/api/models/guild.py index 277aaf5b2..fe44949b6 100644 --- a/interactions/api/models/guild.py +++ b/interactions/api/models/guild.py @@ -1,5 +1,4 @@ from datetime import datetime -from enum import Enum, IntEnum from inspect import isawaitable from math import inf from typing import ( @@ -16,10 +15,12 @@ ) from warnings import warn +from ...client.enums import IntEnum, StrEnum from ...utils.abc.base_iterators import DiscordPaginationIterator from ...utils.attrs_utils import ( ClientSerializerMixin, DictSerializerMixin, + convert_int, convert_list, define, field, @@ -29,6 +30,7 @@ from .audit_log import AuditLogEvents, AuditLogs from .channel import Channel, ChannelType, Thread, ThreadMember from .emoji import Emoji +from .flags import Permissions from .member import Member from .message import Sticker, StickerPack from .misc import ( @@ -123,7 +125,7 @@ class InviteTargetType(IntEnum): EMBEDDED_APPLICATION = 2 -class GuildFeatures(Enum): +class GuildFeatures(StrEnum): ANIMATED_BANNER = "ANIMATED_BANNER" ANIMATED_ICON = "ANIMATED_ICON" BANNER = "BANNER" @@ -246,7 +248,6 @@ def __init__( start_at: Optional[Union[int, str, Snowflake, Member]] = MISSING, check: Optional[Callable[[Member], Union[bool, Awaitable[bool]]]] = None, ): - self.__stop: bool = False super().__init__(obj, _client, maximum=maximum, start_at=start_at, check=check) @@ -256,7 +257,6 @@ def __init__( self.objects: Optional[List[Member]] async def get_first_objects(self) -> None: - limit = min(self.maximum, 1000) if self.maximum == limit: @@ -278,7 +278,6 @@ async def get_first_objects(self) -> None: ] async def get_objects(self) -> None: - limit = min(500, self.maximum - self.object_count) members = await self._client.get_list_of_members( guild_id=self.object_id, after=self.after, limit=limit @@ -307,7 +306,6 @@ async def __anext__(self) -> Member: obj = self.objects.pop(0) if self.check: - res = self.check(obj) _res = await res if isawaitable(res) else res while not _res: @@ -349,7 +347,7 @@ class Guild(ClientSerializerMixin, IDMixin): :ivar Optional[str] discovery_splash: The discovery splash banner of the guild. :ivar Optional[bool] owner: Whether the guild is owned. :ivar Snowflake owner_id: The ID of the owner of the guild. - :ivar Optional[str] permissions: The permissions of the guild. + :ivar Optional[Permissions] permissions: The permissions of the guild. :ivar Optional[str] region: The geographical region of the guild. :ivar Optional[Snowflake] afk_channel_id: The AFK voice channel of the guild. :ivar int afk_timeout: The timeout of the AFK voice channel of the guild. @@ -400,7 +398,9 @@ class Guild(ClientSerializerMixin, IDMixin): discovery_splash: Optional[str] = field(default=None, repr=False) owner: Optional[bool] = field(default=None) owner_id: Snowflake = field(converter=Snowflake, default=None) - permissions: Optional[str] = field(default=None, repr=False) + permissions: Optional[Permissions] = field( + converter=convert_int(Permissions), default=None, repr=False + ) region: Optional[str] = field(default=None, repr=False) # None, we don't do Voices. afk_channel_id: Optional[Snowflake] = field(converter=Snowflake, default=None) afk_timeout: Optional[int] = field(default=None) @@ -703,7 +703,7 @@ async def remove_member_role( async def create_role( self, name: str, - permissions: Optional[int] = MISSING, + permissions: Optional[Union[Permissions, int]] = MISSING, color: Optional[int] = 0, hoist: Optional[bool] = False, icon: Optional[Image] = MISSING, @@ -718,7 +718,7 @@ async def create_role( :param str name: The name of the role :param Optional[int] color: RGB color value as integer, default ``0`` - :param Optional[int] permissions: Bitwise value of the enabled/disabled permissions + :param Optional[Union[Permissions, int]] permissions: Bitwise value of the enabled/disabled permissions :param Optional[bool] hoist: Whether the role should be displayed separately in the sidebar, default ``False`` :param Optional[Image] icon: The role's icon image (if the guild has the ROLE_ICONS feature) :param Optional[str] unicode_emoji: The role's unicode emoji as a standard emoji (if the guild has the ROLE_ICONS feature) @@ -729,7 +729,8 @@ async def create_role( """ if not self._client: raise LibraryException(code=13) - _permissions = permissions if permissions is not MISSING else None + + _permissions = int(permissions) if permissions is not MISSING else None _icon = icon if icon is not MISSING else None _unicode_emoji = unicode_emoji if unicode_emoji is not MISSING else None payload = dict( @@ -839,7 +840,7 @@ async def modify_role( self, role_id: Union[int, Snowflake, Role], name: Optional[str] = MISSING, - permissions: Optional[int] = MISSING, + permissions: Optional[Union[Permissions, int]] = MISSING, color: Optional[int] = MISSING, hoist: Optional[bool] = MISSING, icon: Optional[Image] = MISSING, @@ -855,7 +856,7 @@ async def modify_role( :param Union[int, Snowflake, Role] role_id: The id of the role to edit :param Optional[str] name: The name of the role, defaults to the current value of the role :param Optional[int] color: RGB color value as integer, defaults to the current value of the role - :param Optional[int] permissions: Bitwise value of the enabled/disabled permissions, defaults to the current value of the role + :param Optional[Union[Permissions, int]] permissions: Bitwise value of the enabled/disabled permissions, defaults to the current value of the role :param Optional[bool] hoist: Whether the role should be displayed separately in the sidebar, defaults to the current value of the role :param Optional[Image] icon: The role's icon image (if the guild has the ROLE_ICONS feature), defaults to the current value of the role :param Optional[str] unicode_emoji: The role's unicode emoji as a standard emoji (if the guild has the ROLE_ICONS feature), defaults to the current value of the role @@ -876,7 +877,7 @@ async def modify_role( _color = role.color if color is MISSING else color _hoist = role.hoist if hoist is MISSING else hoist _mentionable = role.mentionable if mentionable is MISSING else mentionable - _permissions = role.permissions if permissions is MISSING else permissions + _permissions = int(role.permissions if permissions is MISSING else permissions) _icon = role.icon if icon is MISSING else icon _unicode_emoji = role.unicode_emoji if unicode_emoji is MISSING else unicode_emoji @@ -2065,7 +2066,6 @@ async def get_all_bans(self) -> List[Dict[str, User]]: res: list = await self._client.get_guild_bans(int(self.id), limit=1000) while len(res) >= 1000: - for ban in res: ban["user"] = User(**ban["user"]) _all.extend(res) diff --git a/interactions/api/models/gw.py b/interactions/api/models/gw.py index 6e5dfbbfa..54cf41619 100644 --- a/interactions/api/models/gw.py +++ b/interactions/api/models/gw.py @@ -9,6 +9,7 @@ define, field, ) +from .audit_log import AuditLogEntry from .channel import Channel, ThreadMember from .emoji import Emoji from .guild import EventMetadata, Guild @@ -40,6 +41,7 @@ "MessageDelete", "MessageReactionRemove", "MessageReaction", + "GuildAuditLogEntry", "GuildIntegrations", "GuildBan", "Webhooks", @@ -206,6 +208,27 @@ class EmbeddedActivity(DictSerializerMixin): channel_id: Snowflake = field(converter=Snowflake) +@define() +class GuildAuditLogEntry(AuditLogEntry): + """ + .. versionadded:: 4.4.0 + + A class object representing the event ``GUILD_AUDIT_LOG_ENTRY_CREATE``. + A derivation of :class:`.AuditLogEntry`. + + :ivar Snowflake guild_id: The guild ID of event. + :ivar Optional[str] target_id: ID of the affected entity (webhook, user, role, etc.) + :ivar Optional[List[AuditLogChange]] changes: Changes made to the target_id + :ivar Optional[Snowflake] user_id: User or app that made the changes + :ivar Snowflake id: ID of the entry + :ivar AuditLogEvents action_type: Type of action that occurred + :ivar OptionalAuditEntryInfo options: Additional info for certain event types + :ivar Optional[str] reason: Reason for the change (1-512 characters) + """ + + guild_id: Snowflake = field(converter=Snowflake) + + @define() class GuildBan(ClientSerializerMixin): """ @@ -648,6 +671,21 @@ def joined(self) -> bool: """ return self.channel_id is not None + @property + def channel(self) -> Optional[Channel]: + """ + Returns the current channel, if cached. + """ + return self._client.cache[Channel].get(self.channel_id) + + @property + def guild(self) -> Optional[Guild]: + """ + Returns the current guild, if cached. + """ + + return self._client.cache[Guild].get(self.guild_id) + async def mute_member(self, reason: Optional[str] = None) -> Member: """ Mutes the current member. @@ -689,15 +727,21 @@ async def get_channel(self) -> Channel: :rtype: Channel """ - return Channel( - **await self._client.get_channel(int(self.channel_id)), - _client=self._client, - ) + if channel := self.channel: + return channel + + res = await self._client.get_channel(int(self.channel_id)) + return Channel(**res, _client=self._client) - async def get_guild(self) -> "Guild": + async def get_guild(self) -> Guild: """ Gets the guild in what the update took place. :rtype: Guild """ - return Guild(**await self._client.get_guild(int(self.guild_id)), _client=self._client) + + if guild := self.guild: + return guild + + res = await self._client.get_guild(int(self.guild_id)) + return Guild(**res, _client=self._client) diff --git a/interactions/api/models/member.py b/interactions/api/models/member.py index d179950f3..e6b989371 100644 --- a/interactions/api/models/member.py +++ b/interactions/api/models/member.py @@ -6,7 +6,7 @@ from ...utils.utils import search_iterable from ..error import LibraryException from .channel import Channel -from .flags import Permissions +from .flags import MemberFlags, Permissions from .misc import AllowedMentions, File, IDMixin, Snowflake from .role import Role from .user import User @@ -38,6 +38,7 @@ class Member(ClientSerializerMixin, IDMixin): :ivar datetime premium_since: The timestamp the member has been a server booster since. :ivar bool deaf: Whether the member is deafened. :ivar bool mute: Whether the member is muted. + :ivar MemberFlags flags: The guild member flags. Default to 0. :ivar Optional[bool] pending: Whether the member is pending to pass membership screening. :ivar Optional[Permissions] permissions: Whether the member has permissions. :ivar Optional[str] communication_disabled_until: How long until they're unmuted, if any. @@ -53,6 +54,7 @@ class Member(ClientSerializerMixin, IDMixin): ) deaf: bool = field() mute: bool = field() + flags: MemberFlags = field(converter=convert_int(MemberFlags), repr=False) is_pending: Optional[bool] = field(default=None, repr=False) pending: Optional[bool] = field(default=None, repr=False) permissions: Optional[Permissions] = field( @@ -64,7 +66,6 @@ class Member(ClientSerializerMixin, IDMixin): hoisted_role: Optional[Any] = field( default=None, repr=False ) # TODO: Investigate what this is for when documented by Discord. - flags: int = field(repr=False) # TODO: Investigate what this is for when documented by Discord. def __getattr__(self, name): # Forward any attributes the user has to make it easier for devs diff --git a/interactions/api/models/message.py b/interactions/api/models/message.py index cc4a8b760..b8b65c401 100644 --- a/interactions/api/models/message.py +++ b/interactions/api/models/message.py @@ -1,9 +1,9 @@ import contextlib from datetime import datetime -from enum import IntEnum from io import BytesIO from typing import TYPE_CHECKING, List, Optional, Union +from ...client.enums import IntEnum from ...client.models.component import ActionRow, Button, SelectMenu from ...utils.attrs_utils import ( ClientSerializerMixin, @@ -111,7 +111,7 @@ class MessageActivity(DictSerializerMixin): @define() -class MessageReference(DictSerializerMixin): +class MessageReference(ClientSerializerMixin): """ A class object representing the "referenced"/replied message. @@ -384,9 +384,11 @@ class Embed(DictSerializerMixin): author: Optional[EmbedAuthor] = field(converter=EmbedAuthor, default=None) fields: Optional[List[EmbedField]] = field(converter=convert_list(EmbedField), default=None) - def add_field(self, name: str, value: str, inline: Optional[bool] = False) -> None: + def add_field(self, name: str, value: str, inline: Optional[bool] = False) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Adds a field to the embed @@ -399,21 +401,27 @@ def add_field(self, name: str, value: str, inline: Optional[bool] = False) -> No self.fields = [] self.fields.append(EmbedField(name=name, value=value, inline=inline)) + return self - def clear_fields(self) -> None: + def clear_fields(self) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Clears all the fields of the embed """ self.fields = [] + return self def insert_field_at( self, index: int, name: str, value: str, inline: Optional[bool] = False - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Inserts a field in the embed at the specified index @@ -427,12 +435,15 @@ def insert_field_at( self.fields = [] self.fields.insert(index, EmbedField(name=name, value=value, inline=inline)) + return self def set_field_at( self, index: int, name: str, value: str, inline: Optional[bool] = False - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Overwrites the field in the embed at the specified index @@ -449,10 +460,13 @@ def set_field_at( self.fields[index] = EmbedField(name=name, value=value, inline=inline) except IndexError as e: raise IndexError("No fields at this index") from e + return self - def remove_field(self, index: int) -> None: + def remove_field(self, index: int) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Remove field at the specified index @@ -466,16 +480,20 @@ def remove_field(self, index: int) -> None: self.fields.pop(index) except IndexError as e: raise IndexError("Field not Found at index") from e + return self - def remove_author(self) -> None: + def remove_author(self) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Removes the embed's author """ with contextlib.suppress(AttributeError): del self.author + return self def set_author( self, @@ -483,9 +501,11 @@ def set_author( url: Optional[str] = None, icon_url: Optional[str] = None, proxy_icon_url: Optional[str] = None, - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Sets the embed's author @@ -498,12 +518,15 @@ def set_author( self.author = EmbedAuthor( name=name, url=url, icon_url=icon_url, proxy_icon_url=proxy_icon_url ) + return self def set_footer( self, text: str, icon_url: Optional[str] = None, proxy_icon_url: Optional[str] = None - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Sets the embed's footer @@ -513,6 +536,7 @@ def set_footer( """ self.footer = EmbedFooter(text=text, icon_url=icon_url, proxy_icon_url=proxy_icon_url) + return self def set_image( self, @@ -520,9 +544,11 @@ def set_image( proxy_url: Optional[str] = None, height: Optional[int] = None, width: Optional[int] = None, - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Sets the embed's image @@ -533,6 +559,7 @@ def set_image( """ self.image = EmbedImageStruct(url=url, proxy_url=proxy_url, height=height, width=width) + return self def set_video( self, @@ -540,9 +567,11 @@ def set_video( proxy_url: Optional[str] = None, height: Optional[int] = None, width: Optional[int] = None, - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Sets the embed's video @@ -553,6 +582,7 @@ def set_video( """ self.video = EmbedImageStruct(url=url, proxy_url=proxy_url, height=height, width=width) + return self def set_thumbnail( self, @@ -560,9 +590,11 @@ def set_thumbnail( proxy_url: Optional[str] = None, height: Optional[int] = None, width: Optional[int] = None, - ) -> None: + ) -> "Embed": """ .. versionadded:: 4.2.0 + .. versionchanged:: 4.4.0 + returns the embed instead of `None` Sets the embed's thumbnail @@ -573,6 +605,7 @@ def set_thumbnail( """ self.thumbnail = EmbedImageStruct(url=url, proxy_url=proxy_url, height=height, width=width) + return self @define() @@ -766,6 +799,9 @@ class Message(ClientSerializerMixin, IDMixin): position: Optional[int] = field(default=None, repr=False) def __attrs_post_init__(self): + if self.referenced_message is not None: + self.referenced_message = Message(**self.referenced_message, _client=self._client) + if self.member and self.guild_id: self.member._extras["guild_id"] = self.guild_id @@ -790,9 +826,6 @@ def created_at(self) -> datetime: """ return self.id.timestamp - if self.referenced_message is not None: - self.referenced_message = Message(**self.referenced_message, _client=self._client) - async def get_channel(self) -> Channel: """ .. versionadded:: 4.0.2 diff --git a/interactions/api/models/misc.py b/interactions/api/models/misc.py index 82624836b..d707e7b11 100644 --- a/interactions/api/models/misc.py +++ b/interactions/api/models/misc.py @@ -7,7 +7,6 @@ import datetime from base64 import b64encode -from enum import Enum, IntEnum from io import FileIO, IOBase from logging import Logger from math import floor @@ -15,7 +14,8 @@ from typing import List, Optional, Union from ...base import get_logger -from ...utils.attrs_utils import DictSerializerMixin, convert_list, define, field +from ...client.enums import IntEnum, StrEnum +from ...utils.attrs_utils import DictSerializerMixin, convert_int, convert_list, define, field from ...utils.missing import MISSING from ..error import LibraryException from .flags import Permissions @@ -48,14 +48,14 @@ class Overwrite(DictSerializerMixin): :ivar str id: Role or User ID :ivar int type: Type that corresponds ot the ID; 0 for role and 1 for member. - :ivar Union[Permissions, int, str] allow: Permission bit set. - :ivar Union[Permissions, int, str] deny: Permission bit set. + :ivar Permissions allow: Permission bit set. + :ivar Permissions deny: Permission bit set. """ id: int = field() type: int = field() - allow: Union[Permissions, int, str] = field() - deny: Union[Permissions, int, str] = field() + allow: Permissions = field(converter=convert_int(Permissions)) + deny: Permissions = field(converter=convert_int(Permissions)) @define() @@ -348,7 +348,6 @@ class File: def __init__( self, filename: str, fp: Optional[IOBase] = MISSING, description: Optional[str] = MISSING ): - if not isinstance(filename, str): raise LibraryException( message=f"File's first parameter 'filename' must be a string, not {str(type(filename))}", @@ -378,7 +377,6 @@ class Image: """ def __init__(self, file: Union[str, FileIO], fp: Optional[IOBase] = MISSING): - self._URI = "data:image/" if fp is MISSING or isinstance(file, FileIO): @@ -413,7 +411,7 @@ def filename(self) -> str: return self._name.split("/")[-1].split(".")[0] -class AllowedMentionType(str, Enum): +class AllowedMentionType(StrEnum): """ .. versionadded:: 4.3.2 diff --git a/interactions/api/models/presence.py b/interactions/api/models/presence.py index 219c2f2c7..77606065a 100644 --- a/interactions/api/models/presence.py +++ b/interactions/api/models/presence.py @@ -1,7 +1,7 @@ import time -from enum import IntEnum from typing import Any, List, Optional +from ...client.enums import IntEnum from ...utils.attrs_utils import DictSerializerMixin, convert_list, define, field from .emoji import Emoji from .flags import StatusType diff --git a/interactions/api/models/role.py b/interactions/api/models/role.py index 1c07e25fb..a7128ddd3 100644 --- a/interactions/api/models/role.py +++ b/interactions/api/models/role.py @@ -1,9 +1,16 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Optional, Union -from ...utils.attrs_utils import ClientSerializerMixin, DictSerializerMixin, define, field +from ...utils.attrs_utils import ( + ClientSerializerMixin, + DictSerializerMixin, + convert_int, + define, + field, +) from ...utils.missing import MISSING from ..error import LibraryException +from .flags import Permissions from .misc import IDMixin, Image, Snowflake if TYPE_CHECKING: @@ -52,7 +59,7 @@ class Role(ClientSerializerMixin, IDMixin): :ivar Optional[str] icon: Role icon hash, if any. :ivar Optional[str] unicode_emoji: Role unicode emoji :ivar int position: Role position - :ivar str permissions: Role permissions as a bit set + :ivar Permissions permissions: Role permissions as a bit set :ivar bool managed: A status denoting if this role is managed by an integration :ivar bool mentionable: A status denoting if this role is mentionable :ivar Optional[RoleTags] tags: The tags this role has @@ -65,7 +72,7 @@ class Role(ClientSerializerMixin, IDMixin): icon: Optional[str] = field(default=None, repr=False) unicode_emoji: Optional[str] = field(default=None) position: int = field() - permissions: str = field() + permissions: Permissions = field(converter=convert_int(Permissions)) managed: bool = field() mentionable: bool = field() tags: Optional[RoleTags] = field(converter=RoleTags, default=None) @@ -117,7 +124,7 @@ async def modify( self, guild_id: Union[int, Snowflake, "Guild"], name: Optional[str] = MISSING, - permissions: Optional[int] = MISSING, + permissions: Optional[Union[Permissions, int]] = MISSING, color: Optional[int] = MISSING, hoist: Optional[bool] = MISSING, icon: Optional[Image] = MISSING, @@ -133,7 +140,7 @@ async def modify( :param int guild_id: The id of the guild to edit the role on :param Optional[str] name: The name of the role, defaults to the current value of the role :param Optional[int] color: RGB color value as integer, defaults to the current value of the role - :param Optional[int] permissions: Bitwise value of the enabled/disabled permissions, defaults to the current value of the role + :param Optional[Union[Permissions, int]] permissions: Bitwise value of the enabled/disabled permissions, defaults to the current value of the role :param Optional[bool] hoist: Whether the role should be displayed separately in the sidebar, defaults to the current value of the role :param Optional[Image] icon: The role's icon image (if the guild has the ROLE_ICONS feature), defaults to the current value of the role :param Optional[str] unicode_emoji: The role's unicode emoji as a standard emoji (if the guild has the ROLE_ICONS feature), defaults to the current value of the role @@ -148,7 +155,7 @@ async def modify( _color = self.color if color is MISSING else color _hoist = self.hoist if hoist is MISSING else hoist _mentionable = self.mentionable if mentionable is MISSING else mentionable - _permissions = self.permissions if permissions is MISSING else permissions + _permissions = int(self.permissions if permissions is MISSING else permissions) _icon = self.icon if icon is MISSING else icon _unicode_emoji = self.unicode_emoji if unicode_emoji is MISSING else unicode_emoji _guild_id = int(guild_id) if isinstance(guild_id, (int, Snowflake)) else int(guild_id.id) diff --git a/interactions/api/models/team.py b/interactions/api/models/team.py index 1b23a1ca3..472ca7abd 100644 --- a/interactions/api/models/team.py +++ b/interactions/api/models/team.py @@ -1,8 +1,7 @@ from datetime import datetime -from enum import IntEnum from typing import Any, Dict, List, Optional, Union -from ...client.enums import Locale +from ...client.enums import IntEnum, Locale from ...utils.attrs_utils import ( ClientSerializerMixin, DictSerializerMixin, @@ -164,6 +163,13 @@ class ApplicationRoleConnectionMetadata(DictSerializerMixin): .. versionadded:: 4.4.0 A class object representing role connection metadata for the application/bot/client. + + :ivar ApplicationRoleConnectionMetadataType type: The type of metadata value. + :ivar str key: The dictionary key for the metadata field. + :ivar str name: The name of the metadata field. + :ivar Optional[Dict[Union[str, Locale], str]] name_localizations: The translations of the name field. + :ivar str description: The description of the metadata field. + :ivar Optional[Dict[Union[str, Locale], str]] description_localizations: The translations of the description field. """ type: ApplicationRoleConnectionMetadataType = field( diff --git a/interactions/api/models/webhook.py b/interactions/api/models/webhook.py index 9afeaf816..fff02cc79 100644 --- a/interactions/api/models/webhook.py +++ b/interactions/api/models/webhook.py @@ -1,9 +1,9 @@ # versionadded is specified in docs gen file from datetime import datetime -from enum import IntEnum from typing import TYPE_CHECKING, List, Optional, Union +from ...client.enums import IntEnum from ...utils.attrs_utils import ClientSerializerMixin, define, field from ...utils.missing import MISSING from ..error import LibraryException diff --git a/interactions/base.py b/interactions/base.py index 7194ef484..2093502ec 100644 --- a/interactions/base.py +++ b/interactions/base.py @@ -6,7 +6,7 @@ "__authors__", ) -__version__ = "4.4.0-beta.1" +__version__ = "4.4.0" __authors__ = { "current": [ diff --git a/interactions/client/bot.py b/interactions/client/bot.py index c2a540b8b..b7b3a6478 100644 --- a/interactions/client/bot.py +++ b/interactions/client/bot.py @@ -22,7 +22,7 @@ from ..api.models.misc import Image, Snowflake from ..api.models.presence import ClientPresence from ..api.models.role import Role -from ..api.models.team import Application +from ..api.models.team import Application, ApplicationRoleConnectionMetadata from ..api.models.user import User from ..base import get_logger from ..utils.attrs_utils import convert_list @@ -71,7 +71,7 @@ class Client: def __init__( self, - token: str, + token: Optional[str] = None, cache_limits: Optional[Dict[type, int]] = None, intents: Intents = Intents.DEFAULT, shards: Optional[List[Tuple[int]]] = None, @@ -189,11 +189,15 @@ def latency(self) -> float: return self._websocket.latency * 1000 - def start(self) -> None: - """Starts the client session.""" + def start(self, token: Optional[str] = None) -> None: + """ + Starts the client session. + + :param Optional[str] token: The token of bot. + """ try: - self._loop.run_until_complete(self._ready()) + self._loop.run_until_complete(self._ready(token=token)) except (CancelledError, Exception) as e: self._loop.run_until_complete(self._logout()) raise e from e @@ -365,7 +369,6 @@ def __check_options(command, data): return clean, _command elif command.get("options") and data.get("options"): - clean = __check_options(command, data) if not clean: @@ -396,7 +399,7 @@ def __check_options(command, data): return clean, _command - async def _ready(self) -> None: + async def _ready(self, token: Optional[str] = None) -> None: """ Prepares the client with an internal "ready" check to ensure that all conditions have been met in a chronological order: @@ -414,7 +417,19 @@ async def _ready(self) -> None: | |___ SYNCHRONIZE | |___ CALLBACK LOOP + + :param Optional[str] token: The token of bot. """ + if self._http and token and self._http is not token: + raise RuntimeError("You cannot pass a token to the bot twice!") + elif not (self._http or token): + raise RuntimeError("No token was passed to the bot!") + + if token: + self._token = token + self._http = token + self._websocket._http = token # Update the websockets token if it wasn't set before + if isinstance(self._http, str): self._http = HTTPClient(self._http, self._cache) @@ -470,7 +485,8 @@ async def _stop(self) -> None: self._websocket._closing_lock.set() # Toggles the "ready-to-shutdown" state for the bot. # And subsequently, the processes will close itself. - await self._http._req._session.close() # Closes the HTTP session associated with the client. + if isinstance(self._http, HTTPClient): + await self._http._req._session.close() # Closes the HTTP session associated with the client. async def _login(self) -> None: """Makes a login with the Discord API.""" @@ -521,7 +537,6 @@ async def _get_all_guilds(self) -> List[dict]: res = await self._http.get_self_guilds(limit=200) while len(res) >= 200: - _all.extend(res) _after = int(res[-1]["id"]) @@ -593,11 +608,10 @@ def __resolve_commands(self) -> None: # sourcery skip: low-code-quality cmd.listener = self._websocket._dispatch if cmd.default_scope and self._default_scope: - cmd.scope = ( + if isinstance(cmd.scope, list): cmd.scope.extend(self._default_scope) - if isinstance(cmd.scope, list) - else self._default_scope - ) + else: + cmd.scope = self._default_scope data: Union[dict, List[dict]] = cmd.full_data coro = cmd.dispatcher @@ -695,6 +709,9 @@ async def __sync(self) -> None: # sourcery no-metrics if _guild_id in __blocked_guilds: log.fatal(f"Cannot sync commands on guild with id {_guild_id}!") raise LibraryException(50001, message="Missing Access |") + if _guild_id not in _guild_ids: + log.warning(f"The bot is not in guild with id {_guild_id}") + continue if _guild_command["name"] not in __check_guild_commands[_guild_id]: self.__guild_commands[_guild_id]["clean"] = False self.__guild_commands[_guild_id]["commands"].append(_guild_command) @@ -1050,6 +1067,7 @@ def command( description_localizations: Optional[Dict[Union[str, Locale], str]] = MISSING, default_member_permissions: Optional[Union[int, Permissions]] = MISSING, dm_permission: Optional[bool] = MISSING, + nsfw: Optional[bool] = MISSING, default_scope: bool = True, ) -> Callable[[Callable[..., Coroutine]], Command]: """ @@ -1103,6 +1121,10 @@ async def sudo(ctx): The dictionary of localization for the ``description`` field. This enforces the same restrictions as the ``description`` field. :param Optional[Union[int, Permissions]] default_member_permissions: The permissions bit value of :class:`.Permissions`. If not given, defaults to :attr:`.Permissions.USE_APPLICATION_COMMANDS` :param Optional[bool] dm_permission: The application permissions if executed in a Direct Message. Defaults to ``True``. + :param Optional[bool] nsfw: + .. versionadded:: 4.4.0 + + Indicates whether the command is age-restricted. Defaults to ``False`` :param Optional[bool] default_scope: .. versionadded:: 4.3.0 @@ -1121,6 +1143,7 @@ def decorator(coro: Callable[..., Coroutine]) -> Command: scope=scope, default_member_permissions=default_member_permissions, dm_permission=dm_permission, + nsfw=nsfw, name_localizations=name_localizations, description_localizations=description_localizations, default_scope=default_scope, @@ -1139,6 +1162,7 @@ def message_command( name_localizations: Optional[Dict[Union[str, Locale], Any]] = MISSING, default_member_permissions: Optional[Union[int, Permissions]] = MISSING, dm_permission: Optional[bool] = MISSING, + nsfw: Optional[bool] = MISSING, default_scope: bool = True, ) -> Callable[[Callable[..., Coroutine]], Command]: """ @@ -1165,6 +1189,10 @@ async def context_menu_name(ctx): The dictionary of localization for the ``name`` field. This enforces the same restrictions as the ``name`` field. :param Optional[Union[int, Permissions]] default_member_permissions: The permissions bit value of :class:`.Permissions`. If not given, defaults to :attr:`.Permissions.USE_APPLICATION_COMMANDS` :param Optional[bool] dm_permission: The application permissions if executed in a Direct Message. Defaults to ``True``. + :param Optional[bool] nsfw: + .. versionadded:: 4.4.0 + + Indicates whether the command is age-restricted. Defaults to ``False`` :param Optional[bool] default_scope: .. versionadded:: 4.3.0 @@ -1180,6 +1208,7 @@ def decorator(coro: Callable[..., Coroutine]) -> Command: scope=scope, default_member_permissions=default_member_permissions, dm_permission=dm_permission, + nsfw=nsfw, name_localizations=name_localizations, default_scope=default_scope, )(coro) @@ -1194,6 +1223,7 @@ def user_command( name_localizations: Optional[Dict[Union[str, Locale], Any]] = MISSING, default_member_permissions: Optional[Union[int, Permissions]] = MISSING, dm_permission: Optional[bool] = MISSING, + nsfw: Optional[bool] = MISSING, default_scope: bool = True, ) -> Callable[[Callable[..., Coroutine]], Command]: """ @@ -1221,6 +1251,10 @@ async def context_menu_name(ctx): :param Optional[Union[int, Permissions]] default_member_permissions: The permissions bit value of :class:`.Permissions`. If not given, defaults to :attr:`.Permissions.USE_APPLICATION_COMMANDS` :param Optional[bool] dm_permission: The application permissions if executed in a Direct Message. Defaults to ``True``. + :param Optional[bool] nsfw: + .. versionadded:: 4.4.0 + + Indicates whether the command is age-restricted. Defaults to ``False`` :param Optional[bool] default_scope: .. versionadded:: 4.3.0 @@ -1236,6 +1270,7 @@ def decorator(coro: Callable[..., Coroutine]) -> Command: scope=scope, default_member_permissions=default_member_permissions, dm_permission=dm_permission, + nsfw=nsfw, name_localizations=name_localizations, default_scope=default_scope, )(coro) @@ -1476,7 +1511,6 @@ def remove( for ext_name, ext in getmembers( extension, lambda x: isinstance(x, type) and issubclass(x, Extension) ): - if ext_name != "Extension": _extension = self._extensions.get(ext_name) with contextlib.suppress(AttributeError): @@ -1602,7 +1636,8 @@ async def request_guild_members( async def _logout(self) -> None: await self._websocket.close() - await self._http._req.close() + if isinstance(self._http, HTTPClient): + await self._http._req.close() async def wait_for( self, @@ -1839,6 +1874,43 @@ async def get_self_user(self) -> User: """ return User(**await self._http.get_self(), _client=self._http) + async def get_role_connection_metadata(self) -> List[ApplicationRoleConnectionMetadata]: + """ + .. versionadded:: 4.4.0 + + Gets the bot's role connection metadata. + + :return: The list of bot's role connection metadata. + """ + + res: List[dict] = await self._http.get_application_role_connection_metadata( + application_id=int(self.me.id) + ) + return [ApplicationRoleConnectionMetadata(**metadata) for metadata in res] + + async def update_role_connection_metadata( + self, + metadata: Union[List[ApplicationRoleConnectionMetadata], ApplicationRoleConnectionMetadata], + ) -> List[ApplicationRoleConnectionMetadata]: + """ + .. versionadded:: 4.4.0 + + Updates the bot's role connection metadata. + + .. note:: + This method overwrites all current bot's role connection metadata. + + :param List[ApplicationRoleConnectionMetadata] metadata: The list of role connection metadata. The maximum is five. + :return: The updated list of bot's role connection metadata. + """ + if not isinstance(metadata, list): + metadata = [metadata] + + res: List[dict] = await self._http.update_application_role_connection_metadata( + application_id=int(self.me.id), payload=[_._json for _ in metadata] + ) + return [ApplicationRoleConnectionMetadata(**_) for _ in res] + class Extension: """ diff --git a/interactions/client/context.py b/interactions/client/context.py index 57c65117b..0926ccaf5 100644 --- a/interactions/client/context.py +++ b/interactions/client/context.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import suppress from datetime import datetime from logging import Logger from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -160,6 +161,55 @@ async def get_guild(self) -> Guild: res = await self._client.get_guild(int(self.guild_id)) return Guild(**res, _client=self._client) + async def defer( + self, ephemeral: Optional[bool] = False, edit_origin: Optional[bool] = False + ) -> Message: + """ + .. versionchanged:: 4.4.0 + Now returns the created message object + + This "defers" an interaction response, allowing up + to a 15-minute delay between invocation and responding. + + :param Optional[bool] ephemeral: Whether the deferred state is hidden or not. + :param Optional[bool] edit_origin: Whether you want to edit the original message or send a followup message + :return: The deferred message + :rtype: Message + """ + if edit_origin and self.type in { + InteractionType.APPLICATION_COMMAND, + InteractionType.APPLICATION_COMMAND_AUTOCOMPLETE, + }: + raise LibraryException( + message="You cannot defer with edit_origin parameter in this type of interaction" + ) + + if not self.responded: + self.deferred = True + is_ephemeral: int = MessageFlags.EPHEMERAL.value if bool(ephemeral) else 0 + # ephemeral doesn't change callback typings. just data json + self.callback = ( + InteractionCallbackType.DEFERRED_UPDATE_MESSAGE + if edit_origin + else InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE + ) + + await self._client.create_interaction_response( + token=self.token, + application_id=int(self.id), + data={"type": self.callback.value, "data": {"flags": is_ephemeral}}, + ) + + with suppress(LibraryException): + res = await self._client.get_original_interaction_response( + self.token, str(self.application_id) + ) + self.message = Message(**res, _client=self._client) + + self.responded = True + + return self.message + async def send( self, content: Optional[str] = MISSING, @@ -231,7 +281,6 @@ async def send( and self.message and self.callback == InteractionCallbackType.DEFERRED_UPDATE_MESSAGE ): - if isinstance(self.message.components, list): _components = self.message.components else: @@ -301,6 +350,9 @@ async def edit( :return: The edited message. """ + if self.message is None: + raise LibraryException(message="There is no message to edit.") + payload = {} if self.message.content is not None or content is not MISSING: @@ -414,6 +466,28 @@ async def has_permissions( else: return any(perm in self.author.permissions for perm in permissions) + async def delete(self) -> None: + """ + This deletes the interaction response of a message sent by + the contextually derived information from this class. + + .. note:: + Doing this will proceed in the context message no longer + being present. + """ + if self.responded and self.message is not None: + await self._client.delete_interaction_response( + application_id=str(self.application_id), + token=self.token, + message_id=int(self.message.id), + ) + else: + await self._client.delete_interaction_response( + application_id=str(self.application_id), token=self.token + ) + + self.message = None + @define() class CommandContext(_Context): @@ -434,7 +508,7 @@ class CommandContext(_Context): :ivar bool deferred: Whether the response was deferred or not. :ivar Optional[Locale] locale: The selected language of the user invoking the interaction. :ivar Optional[Locale] guild_locale: The guild's preferred language, if invoked in a guild. - :ivar str app_permissions: Bitwise set of permissions the bot has within the channel the interaction was sent from. + :ivar Permissions app_permissions: Bitwise set of permissions the bot has within the channel the interaction was sent from. :ivar Client client: .. versionadded:: 4.3.0 @@ -477,7 +551,6 @@ def __attrs_post_init__(self) -> None: async def edit( self, content: Optional[str] = MISSING, **kwargs ) -> Message: # sourcery skip: low-code-quality - payload, files = await super().edit(content, **kwargs) msg = None @@ -519,55 +592,23 @@ async def edit( else: self.message = msg = Message(**res, _client=self._client) else: - try: - res = await self._client.edit_interaction_response( - token=self.token, - application_id=str(self.application_id), - data=payload, - files=files, - ) - except LibraryException as e: - if e.code in {10015, 10018}: - log.warning(f"You can't edit hidden messages." f"({e.message}).") - else: - # if its not ephemeral or some other thing. - raise e from e - else: - self.message = msg = Message(**res, _client=self._client) - - return msg if msg is not None else Message(**payload, _client=self._client) - - async def defer(self, ephemeral: Optional[bool] = False) -> Message: - """ - .. versionchanged:: 4.4.0 - Now returns the created message object - - This "defers" an interaction response, allowing up - to a 15-minute delay between invocation and responding. - - :param Optional[bool] ephemeral: Whether the deferred state is hidden or not. - :return: The deferred message - :rtype: Message - """ - if not self.responded: - self.deferred = True - _ephemeral: int = MessageFlags.EPHEMERAL.value if ephemeral else 0 - self.callback = InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE + self.callback = InteractionCallbackType.UPDATE_MESSAGE await self._client.create_interaction_response( + data={"type": self.callback.value, "data": payload}, + files=files, token=self.token, application_id=int(self.id), - data={"type": self.callback.value, "data": {"flags": _ephemeral}}, ) - try: - _msg = await self._client.get_original_interaction_response( + + with suppress(LibraryException): + res = await self._client.get_original_interaction_response( self.token, str(self.application_id) ) - except LibraryException: - pass - else: - self.message = Message(**_msg, _client=self._client) + self.message = msg = Message(**res, _client=self._client) + self.responded = True - return self.message + + return msg or Message(**payload, _client=self._client) async def send(self, content: Optional[str] = MISSING, **kwargs) -> Message: payload, files = await super().send(content, **kwargs) @@ -594,47 +635,20 @@ async def send(self, content: Optional[str] = MISSING, **kwargs) -> Message: files=files, ) - try: - _msg = await self._client.get_original_interaction_response( + with suppress(LibraryException): + res = await self._client.get_original_interaction_response( self.token, str(self.application_id) ) - except LibraryException: - pass - else: - self.message = msg = Message(**_msg, _client=self._client) + self.message = msg = Message(**res, _client=self._client) self.responded = True - if msg is not None: - return msg - return Message( + return msg or Message( **payload, _client=self._client, author={"_client": self._client, "id": None, "username": None, "discriminator": None}, ) - async def delete(self) -> None: - """ - This deletes the interaction response of a message sent by - the contextually derived information from this class. - - .. note:: - Doing this will proceed in the context message no longer - being present. - """ - if self.responded and self.message is not None: - await self._client.delete_interaction_response( - application_id=str(self.application_id), - token=self.token, - message_id=int(self.message.id), - ) - else: - await self._client.delete_interaction_response( - application_id=str(self.application_id), token=self.token - ) - - self.message = None - async def populate(self, choices: Union[Choice, List[Choice]]) -> List[Choice]: """ This "populates" the list of choices that the client-end @@ -696,7 +710,7 @@ class ComponentContext(_Context): :ivar bool deferred: Whether the response was deferred or not. :ivar Optional[Locale] locale: The selected language of the user invoking the interaction. :ivar Optional[Locale] guild_locale: The guild's preferred language, if invoked in a guild. - :ivar str app_permissions: Bitwise set of permissions the bot has within the channel the interaction was sent from. + :ivar Permissions app_permissions: Bitwise set of permissions the bot has within the channel the interaction was sent from. """ async def edit(self, content: Optional[str] = MISSING, **kwargs) -> Message: @@ -712,14 +726,12 @@ async def edit(self, content: Optional[str] = MISSING, **kwargs) -> Message: application_id=int(self.id), ) - try: - _msg = await self._client.get_original_interaction_response( + with suppress(LibraryException): + res = await self._client.get_original_interaction_response( self.token, str(self.application_id) ) - except LibraryException: - pass - else: - self.message = msg = Message(**_msg, _client=self._client) + + self.message = msg = Message(**res, _client=self._client) self.responded = True elif self.callback != InteractionCallbackType.DEFERRED_UPDATE_MESSAGE: @@ -739,7 +751,7 @@ async def edit(self, content: Optional[str] = MISSING, **kwargs) -> Message: self.responded = True self.message = msg = Message(**res, _client=self._client) - return msg if msg is not None else Message(**payload, _client=self._client) + return msg or Message(**payload, _client=self._client) async def send(self, content: Optional[str] = MISSING, **kwargs) -> Message: payload, files = await super().send(content, **kwargs) @@ -766,60 +778,16 @@ async def send(self, content: Optional[str] = MISSING, **kwargs) -> Message: files=files, ) - try: - _msg = await self._client.get_original_interaction_response( + with suppress(LibraryException): + res = await self._client.get_original_interaction_response( self.token, str(self.application_id) ) - except LibraryException: - pass - else: - self.message = msg = Message(**_msg, _client=self._client) + self.message = msg = Message(**res, _client=self._client) self.responded = True return msg if msg is not None else Message(**payload, _client=self._client) - async def defer( - self, ephemeral: Optional[bool] = False, edit_origin: Optional[bool] = False - ) -> Message: - """ - .. versionchanged:: 4.4.0 - Now returns the created message object - - This "defers" a component response, allowing up - to a 15-minute delay between invocation and responding. - - :param Optional[bool] ephemeral: Whether the deferred state is hidden or not. - :param Optional[bool] edit_origin: Whether you want to edit the original message or send a followup message - :return: The deferred message - :rtype: Message - """ - if not self.responded: - - self.deferred = True - _ephemeral: int = MessageFlags.EPHEMERAL.value if bool(ephemeral) else 0 - # ephemeral doesn't change callback typings. just data json - if edit_origin: - self.callback = InteractionCallbackType.DEFERRED_UPDATE_MESSAGE - else: - self.callback = InteractionCallbackType.DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE - - await self._client.create_interaction_response( - token=self.token, - application_id=int(self.id), - data={"type": self.callback.value, "data": {"flags": _ephemeral}}, - ) - try: - _msg = await self._client.get_original_interaction_response( - self.token, str(self.application_id) - ) - except LibraryException: - pass - else: - self.message = Message(**_msg, _client=self._client) - self.responded = True - return self.message - async def disable_all_components( self, respond_to_interaction: Optional[bool] = True, **other_kwargs: Optional[dict] ) -> Message: @@ -876,3 +844,20 @@ def label(self) -> Optional[str]: for component in action_row.components: if component.custom_id == self.custom_id: return component.label + + @property + def component(self) -> Optional[Union[Button, SelectMenu]]: + """ + .. versionadded:: 4.4.0 + + The component that you interacted. + + :rtype: Optional[Union[Button, SelectMenu]] + """ + if self.message is None or self.message.components is None: + return + + for action_row in self.message.components: + for component in action_row.components: + if component.custom_id == self.custom_id: + return component diff --git a/interactions/client/decor.py b/interactions/client/decor.py index 1cb876ce1..91edfed4f 100644 --- a/interactions/client/decor.py +++ b/interactions/client/decor.py @@ -20,7 +20,8 @@ def command( name_localizations: Optional[Dict[Union[str, Locale], str]] = MISSING, description_localizations: Optional[Dict[Union[str, Locale], str]] = MISSING, default_member_permissions: Optional[Union[int, Permissions]] = MISSING, - dm_permission: Optional[bool] = MISSING + dm_permission: Optional[bool] = MISSING, + nsfw: Optional[bool] = MISSING, ) -> Union[List[dict], dict]: # sourcery skip: low-code-quality """ A wrapper designed to interpret the client-facing API for @@ -78,6 +79,7 @@ def command( ) ) _dm_permission: bool = True if dm_permission is MISSING else dm_permission + _nsfw: bool = False if nsfw is MISSING else nsfw payloads: list = [] @@ -102,6 +104,7 @@ def command( description_localizations=_description_localizations, default_member_permissions=_default_member_permissions, dm_permission=_dm_permission, + nsfw=_nsfw, ) payloads.append(payload._json) else: @@ -114,6 +117,7 @@ def command( description_localizations=_description_localizations, default_member_permissions=_default_member_permissions, dm_permission=_dm_permission, + nsfw=_nsfw, ) return payload._json diff --git a/interactions/client/enums.py b/interactions/client/enums.py index 6feaa467d..203c70fe6 100644 --- a/interactions/client/enums.py +++ b/interactions/client/enums.py @@ -1,6 +1,14 @@ -from enum import Enum, IntEnum +import logging +from enum import Enum +from typing import Any, Type + +from ..base import get_logger + +log: logging.Logger = get_logger("enums") __all__ = ( + "IntEnum", + "StrEnum", "ApplicationCommandType", "InteractionType", "InteractionCallbackType", @@ -13,6 +21,35 @@ ) +def _cursed_enum(cls: Type[Enum], obj: type, value: Any) -> Enum: + log.info(f"Enum class {cls.__name__} received an unexpected value `{value}`.") + + new = obj.__new__(cls) # type: ignore + new._name_ = f"UNKNOWN: {value}" + new._value_ = value + + return cls._value2member_map_.setdefault(value, new) + + +class IntEnum(int, Enum): + """Enum where members must be ints""" + + @classmethod + def _missing_(cls, value: int) -> Enum: + return _cursed_enum(cls, int, value) + + +class StrEnum(str, Enum): + """Enum where members must be strings""" + + @classmethod + def _missing_(cls, value: str) -> Enum: + return _cursed_enum(cls, str, value) + + def __str__(self): + return self.value + + class ApplicationCommandType(IntEnum): """ An enumerable object representing the types of application commands. @@ -168,7 +205,7 @@ class TextStyleType(IntEnum): PARAGRAPH = 2 -class Locale(str, Enum): +class Locale(StrEnum): """ .. versionadded:: 4.2.0 diff --git a/interactions/client/models/command.py b/interactions/client/models/command.py index 4d9f50023..6cceb31ef 100644 --- a/interactions/client/models/command.py +++ b/interactions/client/models/command.py @@ -173,9 +173,10 @@ class ApplicationCommand(DictSerializerMixin): :ivar str description: The description of the application command. :ivar Optional[List[Option]] options: The "options"/arguments of the application command. :ivar Optional[bool] default_permission: The default permission accessibility state of the application command. + :ivar Optional[bool] nsfw: Indicates whether the command is age-restricted. :ivar int version: The Application Command version autoincrement identifier. :ivar str default_member_permissions: The default member permission state of the application command. - :ivar boolean dm_permission: The application permissions if executed in a Direct Message. + :ivar bool dm_permission: The application permissions if executed in a Direct Message. :ivar Optional[Dict[Union[str, Locale], str]] name_localizations: The localisation dictionary for the application command name, if any. :ivar Optional[Dict[Union[str, Locale], str]] description_localizations: The localisation dictionary for the application command description, if any. """ @@ -188,6 +189,7 @@ class ApplicationCommand(DictSerializerMixin): description: str = field() options: Optional[List[Option]] = field(converter=convert_list(Option), default=None) default_permission: Optional[bool] = field(default=None) + nsfw: Optional[bool] = field(default=None) version: int = field(default=None) default_member_permissions: str = field() dm_permission: bool = field(default=None) @@ -386,6 +388,7 @@ class Command(DictSerializerMixin): :ivar Optional[Union[int, Guild, List[int], List[Guild]]] scope: The scope of the command. :ivar Optional[str] default_member_permissions: The default member permissions of the command. :ivar Optional[bool] dm_permission: The DM permission of the command. + :ivar Optional[bool] nsfw: Indicates whether the command is age-restricted. Defaults to ``False``. :ivar Optional[Dict[Union[str, Locale], str]] name_localizations: The dictionary of localization for the ``name`` field. This enforces the same restrictions as the ``name`` field. :ivar Optional[Dict[Union[str, Locale], str]] description_localizations: The dictionary of localization for the ``description`` field. This enforces the same restrictions as the ``description`` field. :ivar bool default_scope: Whether the command should use the default scope. Defaults to ``True``. @@ -407,6 +410,7 @@ class Command(DictSerializerMixin): scope: Optional[Union[int, Guild, List[int], List[Guild]]] = field(default=MISSING) default_member_permissions: Optional[str] = field(default=MISSING) dm_permission: Optional[bool] = field(default=MISSING) + nsfw: Optional[bool] = field(default=None) name_localizations: Optional[Dict[Union[str, Locale], str]] = field(default=MISSING) description_localizations: Optional[Dict[Union[str, Locale], str]] = field(default=MISSING) default_scope: bool = field(default=True) @@ -479,6 +483,7 @@ def full_data(self) -> Union[dict, List[dict]]: description_localizations=self.description_localizations, default_member_permissions=self.default_member_permissions, dm_permission=self.dm_permission, + nsfw=self.nsfw, ) @property @@ -897,6 +902,8 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs): ctx.command = self ctx.extension = self.extension + self.listener.dispatch("on_command", ctx) + try: if self.extension: return await coro(self.extension, ctx, *args, **kwargs) diff --git a/interactions/client/models/component.py b/interactions/client/models/component.py index db4f677db..194f27919 100644 --- a/interactions/client/models/component.py +++ b/interactions/client/models/component.py @@ -277,7 +277,6 @@ def new(cls, *components: Union[Button, SelectMenu, TextInput]) -> List["ActionR def _build_components(components) -> List[dict]: # sourcery no-metrics def __check_action_row(): - if isinstance(components, list) and all( isinstance(action_row, (list, ActionRow)) for action_row in components ): diff --git a/interactions/client/models/misc.py b/interactions/client/models/misc.py index 3599d6df2..a440351ad 100644 --- a/interactions/client/models/misc.py +++ b/interactions/client/models/misc.py @@ -27,12 +27,18 @@ class InteractionResolvedData(DictSerializerMixin): :ivar Dict[str, Attachment] attachments: The resolved attachments data. """ - users: Dict[str, User] = field(converter=convert_dict(value_converter=User)) - members: Dict[str, Member] = field(converter=convert_dict(value_converter=Member)) - roles: Dict[str, Role] = field(converter=convert_dict(value_converter=Role)) - channels: Dict[str, Channel] = field(converter=convert_dict(value_converter=Channel)) - messages: Dict[str, Message] = field(converter=convert_dict(value_converter=Message)) - attachments: Dict[str, Attachment] = field(converter=convert_dict(value_converter=Attachment)) + users: Dict[str, User] = field(converter=convert_dict(value_converter=User), factory=dict) + members: Dict[str, Member] = field(converter=convert_dict(value_converter=Member), factory=dict) + roles: Dict[str, Role] = field(converter=convert_dict(value_converter=Role), factory=dict) + channels: Dict[str, Channel] = field( + converter=convert_dict(value_converter=Channel), factory=dict + ) + messages: Dict[str, Message] = field( + converter=convert_dict(value_converter=Message), factory=dict + ) + attachments: Dict[str, Attachment] = field( + converter=convert_dict(value_converter=Attachment), factory=dict + ) def __attrs_post_init__(self): if self.members: diff --git a/interactions/ext/error.py b/interactions/ext/error.py index 426c1ce52..094bdc013 100644 --- a/interactions/ext/error.py +++ b/interactions/ext/error.py @@ -1,4 +1,4 @@ -from enum import Enum +from ..client.enums import StrEnum __all__ = ( "ErrorType", @@ -9,7 +9,7 @@ ) -class ErrorType(str, Enum): +class ErrorType(StrEnum): """ An enumerable object representing the type of error responses raised. diff --git a/interactions/ext/version.py b/interactions/ext/version.py index 2b363dc8f..096b23265 100644 --- a/interactions/ext/version.py +++ b/interactions/ext/version.py @@ -1,8 +1,8 @@ -from enum import Enum from hashlib import md5 from string import ascii_lowercase from typing import List, Optional, Union +from ..client.enums import StrEnum from .error import IncorrectAlphanumeric, TooManyAuthors __all__ = ( @@ -12,7 +12,7 @@ ) -class VersionAlphanumericType(str, Enum): +class VersionAlphanumericType(StrEnum): ALPHA = "alpha" BETA = "beta" RELEASE_CANDIDATE = "rc" diff --git a/interactions/utils/get.py b/interactions/utils/get.py index 917604e07..6501a73d3 100644 --- a/interactions/utils/get.py +++ b/interactions/utils/get.py @@ -1,7 +1,6 @@ # versionadded declared in docs gen file from asyncio import sleep -from enum import Enum from inspect import isawaitable from logging import getLogger from sys import version_info @@ -28,6 +27,7 @@ class GenericAlias: from ..api.models.message import Message from ..api.models.misc import Snowflake from ..api.models.role import Role +from ..client.enums import StrEnum log = getLogger("get") @@ -43,7 +43,7 @@ class GenericAlias: ) -class Force(str, Enum): +class Force(StrEnum): """ An enumerable object representing the force types for the get method. @@ -318,7 +318,6 @@ def _get_cache( def _resolve_kwargs(obj, **kwargs): # This function is needed to get correct kwarg names if __id := kwargs.pop("parent_id", None): - if version_info >= (3, 9): _list = [Message, List[Message], list[Message]] else: diff --git a/interactions/utils/get.pyi b/interactions/utils/get.pyi index 6ecf191fa..cf2094433 100644 --- a/interactions/utils/get.pyi +++ b/interactions/utils/get.pyi @@ -1,8 +1,7 @@ -from enum import Enum from typing import Awaitable, Coroutine, List, Literal, Optional, Type, TypeVar, Union, overload -from interactions.client.bot import Client - +from ..client.bot import Client +from ..client.enums import StrEnum from ..api.http.client import HTTPClient from ..api.models.channel import Channel from ..api.models.guild import Guild @@ -19,7 +18,7 @@ _T = TypeVar("_T") __all__: tuple -class Force(str, Enum): +class Force(StrEnum): """ An enum representing the force methods for the get method """ diff --git a/interactions/utils/utils.py b/interactions/utils/utils.py index c6ccacf83..db39b0be8 100644 --- a/interactions/utils/utils.py +++ b/interactions/utils/utils.py @@ -232,7 +232,7 @@ def disable_components( elif isinstance(components, list): if not all( isinstance(component, (Button, SelectMenu)) for component in components - ) or not all(isinstance(component, (ActionRow, Component)) for component in components): + ) and not all(isinstance(component, (ActionRow, Component)) for component in components): raise LibraryException( 12, "You must only specify lists of 'Buttons' and 'SelectMenus' or 'ActionRow' and 'Component'",