diff --git a/hikari/api/entity_factory.py b/hikari/api/entity_factory.py index ca2fce649a..277eaa52cf 100644 --- a/hikari/api/entity_factory.py +++ b/hikari/api/entity_factory.py @@ -31,7 +31,6 @@ import attr from hikari import undefined -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import applications as application_models @@ -53,8 +52,7 @@ from hikari.internal import data_binding -@attr_extensions.with_copy -@attr.define(weakref_slot=False) +@attr.frozen(weakref_slot=False) class GatewayGuildDefinition: """A structure for handling entities within guild create and update events.""" diff --git a/hikari/api/event_manager.py b/hikari/api/event_manager.py index 90a0c2c2b9..126fd43b63 100644 --- a/hikari/api/event_manager.py +++ b/hikari/api/event_manager.py @@ -97,7 +97,7 @@ def dispatch(self, event: EventT_inv) -> asyncio.Future[typing.Any]: from hikari.users import User from hikari.snowflakes import Snowflake - @attr.define() + @attr.frozen() class EveryoneMentionedEvent(Event): app: RESTAware = attr.field() diff --git a/hikari/applications.py b/hikari/applications.py index 2d3a14bd80..fb4bf5556c 100644 --- a/hikari/applications.py +++ b/hikari/applications.py @@ -52,7 +52,6 @@ from hikari import snowflakes from hikari import urls from hikari import users -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -232,8 +231,7 @@ class ConnectionVisibility(int, enums.Enum): """Everyone can see the connection.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class OwnConnection: """Represents a user's connection with a third party account. @@ -272,7 +270,7 @@ class OwnConnection: """The visibility of the connection.""" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class OwnGuild(guilds.PartialGuild): """Represents a user bound partial guild object.""" @@ -297,8 +295,7 @@ class TeamMembershipState(int, enums.Enum): """Denotes the user has accepted the invite and is now a member.""" -@attr_extensions.with_copy -@attr.define(eq=False, hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(eq=False, hash=False, kw_only=True, weakref_slot=False) class TeamMember(users.User): """Represents a member of a Team.""" @@ -347,10 +344,6 @@ def flags(self) -> users.UserFlag: def id(self) -> snowflakes.Snowflake: return self.user.id - @id.setter - def id(self, value: snowflakes.Snowflake) -> typing.NoReturn: - raise TypeError("Cannot mutate the ID of a member") - @property def is_bot(self) -> bool: return self.user.is_bot @@ -380,14 +373,11 @@ def __eq__(self, other: object) -> bool: return self.user == other -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Team(snowflakes.Unique): """Represents a development team, along with all its members.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -461,14 +451,11 @@ def make_icon_url(self, *, ext: str = "png", size: int = 4096) -> typing.Optiona ) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Application(guilds.PartialApplication): """Represents the information of an Oauth2 Application.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" is_bot_public: typing.Optional[bool] = attr.field(eq=False, hash=False, repr=True) @@ -565,8 +552,7 @@ def make_cover_image_url(self, *, ext: str = "png", size: int = 4096) -> typing. ) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class AuthorizationApplication(guilds.PartialApplication): """The application model found attached to `AuthorizationInformation`.""" @@ -592,8 +578,7 @@ class AuthorizationApplication(guilds.PartialApplication): """The URL of this application's privacy policy.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class AuthorizationInformation: """Model for the data returned by Get Current Authorization Information.""" @@ -610,8 +595,7 @@ class AuthorizationInformation: """The user who has authorized this token if they included the `identify` scope.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialOAuth2Token: """Model for partial OAuth2 token data returned by the API. @@ -635,8 +619,7 @@ def __str__(self) -> str: return self.access_token -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class OAuth2AuthorizationToken(PartialOAuth2Token): """Model for the OAuth2 token data returned by the authorization grant flow.""" @@ -658,8 +641,7 @@ class OAuth2AuthorizationToken(PartialOAuth2Token): """ -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class OAuth2ImplicitToken(PartialOAuth2Token): """Model for the OAuth2 token data returned by the implicit grant flow.""" diff --git a/hikari/audit_logs.py b/hikari/audit_logs.py index 28cf5133f1..63b09092d0 100644 --- a/hikari/audit_logs.py +++ b/hikari/audit_logs.py @@ -40,14 +40,12 @@ "MessagePinEntryInfo", ] -import abc import typing import attr from hikari import channels from hikari import snowflakes -from hikari.internal import attr_extensions from hikari.internal import collections from hikari.internal import enums @@ -126,8 +124,7 @@ class AuditLogChangeKey(str, enums.Enum): """Alias for "COLOR""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class AuditLogChange: """Represents a change made to an audit log entry's target entity.""" @@ -182,16 +179,15 @@ class AuditLogEventType(int, enums.Enum): INTEGRATION_DELETE = 82 -@attr.define(hash=False, kw_only=True, weakref_slot=False) -class BaseAuditLogEntryInfo(abc.ABC): +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) +class BaseAuditLogEntryInfo: """A base object that all audit log entry info objects will inherit from.""" - app: traits.RESTAware = attr.field(repr=False, eq=False, metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field(repr=False, eq=False) """The client application that models may use for procedures.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class ChannelOverwriteEntryInfo(BaseAuditLogEntryInfo, snowflakes.Unique): """Represents the extra information for overwrite related audit log entries. @@ -209,8 +205,7 @@ class ChannelOverwriteEntryInfo(BaseAuditLogEntryInfo, snowflakes.Unique): """The name of the role this overwrite targets, if it targets a role.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MessagePinEntryInfo(BaseAuditLogEntryInfo): """The extra information for message pin related audit log entries. @@ -290,8 +285,7 @@ async def fetch_message(self) -> messages.Message: return await self.app.rest.fetch_message(self.channel_id, self.message_id) -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MemberPruneEntryInfo(BaseAuditLogEntryInfo): """Extra information attached to guild prune log entries.""" @@ -302,8 +296,7 @@ class MemberPruneEntryInfo(BaseAuditLogEntryInfo): """The number of members who were removed by this prune.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MessageBulkDeleteEntryInfo(BaseAuditLogEntryInfo): """Extra information for the message bulk delete audit entry.""" @@ -311,8 +304,7 @@ class MessageBulkDeleteEntryInfo(BaseAuditLogEntryInfo): """The amount of messages that were deleted.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MessageDeleteEntryInfo(MessageBulkDeleteEntryInfo): """Extra information attached to the message delete audit entry.""" @@ -354,8 +346,7 @@ async def fetch_channel(self) -> channels.GuildTextChannel: return channel -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MemberDisconnectEntryInfo(BaseAuditLogEntryInfo): """Extra information for the voice chat member disconnect entry.""" @@ -363,8 +354,7 @@ class MemberDisconnectEntryInfo(BaseAuditLogEntryInfo): """The amount of members who were disconnected from voice in this entry.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MemberMoveEntryInfo(MemberDisconnectEntryInfo): """Extra information for the voice chat based member move entry.""" @@ -406,14 +396,11 @@ async def fetch_channel(self) -> channels.GuildVoiceChannel: return channel -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class AuditLogEntry(snowflakes.Unique): """Represents an entry in a guild's audit log.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -470,8 +457,7 @@ async def fetch_user(self) -> typing.Optional[users_.User]: return await self.app.rest.fetch_user(self.user_id) -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, repr=False, weakref_slot=False) class AuditLog(typing.Sequence[AuditLogEntry]): """Represents a guilds audit log.""" diff --git a/hikari/channels.py b/hikari/channels.py index 52a2f531ba..208016b676 100644 --- a/hikari/channels.py +++ b/hikari/channels.py @@ -54,7 +54,6 @@ from hikari import traits from hikari import undefined from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -111,8 +110,7 @@ class VideoQualityMode(int, enums.Enum): """Video quality will be set to 720p.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class ChannelFollow: """Relationship between a news channel and a subscriber channel. @@ -120,9 +118,7 @@ class ChannelFollow: to any "broadcast" announcements that the news channel creates. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """Return the client application that models may use for procedures. Returns @@ -252,8 +248,7 @@ class PermissionOverwriteType(int, enums.Enum): """A permission overwrite that targets a specific guild member.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PermissionOverwrite(snowflakes.Unique): """Represents permission overwrites for a channel or role in a channel. @@ -312,8 +307,7 @@ def unset(self) -> permissions.Permissions: return ~(self.allow | self.deny) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialChannel(snowflakes.Unique): """Channel representation for cases where further detail is not provided. @@ -321,9 +315,7 @@ class PartialChannel(snowflakes.Unique): not available from Discord. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -580,7 +572,7 @@ def trigger_typing(self) -> special_endpoints.TypingIndicator: return self.app.rest.trigger_typing(self.id) -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PrivateChannel(PartialChannel): """The base for anything that is a private (non-guild bound) channel.""" @@ -593,7 +585,7 @@ class PrivateChannel(PartialChannel): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class DMChannel(PrivateChannel, TextChannel): """Represents a direct message text channel that is between you and another user.""" @@ -609,7 +601,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__} with: {self.recipient}" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GroupDMChannel(PrivateChannel): """Represents a group direct message channel. @@ -682,7 +674,7 @@ def make_icon_url(self, *, ext: str = "png", size: int = 4096) -> typing.Optiona ) -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildChannel(PartialChannel): """The base for anything that is a guild channel.""" @@ -723,17 +715,13 @@ def shard_id(self) -> typing.Optional[int]: This may be `builtins.None` if the shard count is not known. """ - try: - shard_count = getattr(self.app, "shard_count") - assert isinstance(shard_count, int), f"shard_count attr was expected to be int, but got {shard_count}" - return snowflakes.calculate_shard_id(shard_count, self.guild_id) - except (TypeError, AttributeError, NameError): - pass + if isinstance(self.app, traits.ShardAware): + return snowflakes.calculate_shard_id(self.app, self.guild_id) return None -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildCategory(GuildChannel): """Represents a guild category channel. @@ -742,7 +730,7 @@ class GuildCategory(GuildChannel): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildTextChannel(GuildChannel, TextChannel): """Represents a guild text channel.""" @@ -777,7 +765,7 @@ class GuildTextChannel(GuildChannel, TextChannel): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildNewsChannel(GuildChannel, TextChannel): """Represents an news channel.""" @@ -801,7 +789,7 @@ class GuildNewsChannel(GuildChannel, TextChannel): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildStoreChannel(GuildChannel): """Represents a store channel. @@ -811,7 +799,7 @@ class GuildStoreChannel(GuildChannel): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildVoiceChannel(GuildChannel): """Represents a voice channel.""" @@ -836,7 +824,7 @@ class GuildVoiceChannel(GuildChannel): """The video quality mode for the voice channel.""" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildStageChannel(GuildChannel): """Represents a stage channel.""" diff --git a/hikari/config.py b/hikari/config.py index 25ab27a88f..9481f2925d 100644 --- a/hikari/config.py +++ b/hikari/config.py @@ -40,7 +40,6 @@ import attr import yarl -from hikari.internal import attr_extensions from hikari.internal import data_binding from hikari.internal import enums @@ -60,7 +59,6 @@ def _ssl_factory(value: typing.Union[bool, ssl_.SSLContext]) -> ssl_.SSLContext: return ssl -@attr_extensions.with_copy @attr.define(kw_only=True, repr=True, weakref_slot=False) class BasicAuthHeader: """An object that can be set as a producer for a basic auth header.""" @@ -113,7 +111,6 @@ def __str__(self) -> str: return self.header -@attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) class ProxySettings: """Settings for configuring an HTTP-based proxy.""" @@ -200,7 +197,6 @@ def all_headers(self) -> typing.Optional[data_binding.Headers]: return {**self.headers, _PROXY_AUTHENTICATION_HEADER: self.auth} -@attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) class HTTPTimeoutSettings: """Settings to control HTTP request timeouts.""" @@ -260,7 +256,6 @@ def _(self, attrib: attr.Attribute[typing.Optional[float]], value: typing.Option raise ValueError(f"HTTPTimeoutSettings.{attrib.name} must be None, or a POSITIVE float/int") -@attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) class HTTPSettings: """Settings to control HTTP clients.""" @@ -432,7 +427,6 @@ class CacheComponents(enums.Flag): """Fully enables the cache.""" -@attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) class CacheSettings: """Settings to control the cache.""" diff --git a/hikari/embeds.py b/hikari/embeds.py index aa0511bb68..67a9ecabcb 100644 --- a/hikari/embeds.py +++ b/hikari/embeds.py @@ -46,7 +46,6 @@ from hikari import errors from hikari import files from hikari import undefined -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: import concurrent.futures @@ -55,8 +54,7 @@ AsyncReaderT = typing.TypeVar("AsyncReaderT", bound=files.AsyncReader) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class EmbedResource(files.Resource[AsyncReaderT]): """A base type for any resource provided in an embed. @@ -111,7 +109,7 @@ def stream( return self.resource.stream(executor=executor, head_only=head_only) -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class EmbedResourceWithProxy(EmbedResource[AsyncReaderT]): """Resource with a corresponding proxied element.""" @@ -151,7 +149,6 @@ def proxy_filename(self) -> typing.Optional[str]: return self.proxy_resource.filename if self.proxy_resource else None -@attr_extensions.with_copy @attr.define(hash=False, kw_only=True, weakref_slot=False) class EmbedFooter: """Represents an embed footer.""" @@ -188,7 +185,7 @@ class EmbedImage(EmbedResourceWithProxy[AsyncReaderT]): """ -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class EmbedVideo(EmbedResourceWithProxy[AsyncReaderT]): """Represents an embed video. @@ -208,8 +205,7 @@ class yourself.** """The width of the video.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class EmbedProvider: """Represents an embed provider. @@ -230,7 +226,6 @@ class yourself.** """The URL of the provider.""" -@attr_extensions.with_copy @attr.define(hash=False, kw_only=True, weakref_slot=False) class EmbedAuthor: """Represents an author of an embed.""" @@ -248,7 +243,6 @@ class EmbedAuthor: """The author's icon, or `builtins.None` if not present.""" -@attr_extensions.with_copy @attr.define(hash=False, kw_only=True, weakref_slot=False) class EmbedField: """Represents a field in a embed.""" @@ -276,6 +270,7 @@ def is_inline(self, value: bool) -> None: self._inline = value +# TODO: separate into frozen received embed and embed builder class Embed: """Represents an embed.""" @@ -312,7 +307,7 @@ def from_received_embed( author: typing.Optional[EmbedAuthor], provider: typing.Optional[EmbedProvider], footer: typing.Optional[EmbedFooter], - fields: typing.Optional[typing.MutableSequence[EmbedField]], + fields: typing.Optional[typing.Sequence[EmbedField]], ) -> Embed: """Generate an embed from the given attributes. @@ -332,7 +327,7 @@ def from_received_embed( embed._author = author embed._provider = provider embed._footer = footer - embed._fields = fields + embed._fields = list(fields) if fields else None return embed def __init__( @@ -733,13 +728,9 @@ def set_author( if name is None and url is None and icon is None: self._author = None else: - self._author = EmbedAuthor() - self._author.name = name - self._author.url = url - if icon is not None: - self._author.icon = EmbedResourceWithProxy(resource=files.ensure_resource(icon)) - else: - self._author.icon = None + self._author = EmbedAuthor( + name=name, url=url, icon=EmbedResourceWithProxy(resource=files.ensure_resource(icon)) + ) return self def set_footer(self, *, text: typing.Optional[str], icon: typing.Optional[files.Resourceish] = None) -> Embed: @@ -787,12 +778,8 @@ def set_footer(self, *, text: typing.Optional[str], icon: typing.Optional[files. self._footer = None else: - self._footer = EmbedFooter() - self._footer.text = text - if icon is not None: - self._footer.icon = EmbedResourceWithProxy(resource=files.ensure_resource(icon)) - else: - self._footer.icon = None + icon = EmbedResourceWithProxy(resource=files.ensure_resource(icon)) if icon else None + self._footer = EmbedFooter(icon=icon, text=text) return self def set_image(self, image: typing.Optional[files.Resourceish] = None, /) -> Embed: diff --git a/hikari/emojis.py b/hikari/emojis.py index 7d1ac99225..420b01dfcc 100644 --- a/hikari/emojis.py +++ b/hikari/emojis.py @@ -35,7 +35,6 @@ from hikari import files from hikari import snowflakes from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import routes # import unicodedata @@ -49,7 +48,6 @@ _CUSTOM_EMOJI_REGEX: typing.Final[typing.Pattern[str]] = re.compile(r"<(?P[^:]*):(?P[^:]*):(?P\d+)>") -@attr.define(hash=True, kw_only=True, weakref_slot=False) class Emoji(files.WebResource, abc.ABC): """Base class for all emojis. @@ -59,6 +57,8 @@ class Emoji(files.WebResource, abc.ABC): `hikari.files.WebResource` would achieve this. """ + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def name(self) -> typing.Optional[str]: @@ -109,8 +109,7 @@ def parse(cls, string: str, /) -> Emoji: return UnicodeEmoji.parse(string) -@attr_extensions.with_copy -@attr.define(hash=True, weakref_slot=False) +@attr.frozen(hash=True, weakref_slot=False) class UnicodeEmoji(Emoji): """Represents a unicode emoji. @@ -245,8 +244,7 @@ def parse(cls, string: str, /) -> UnicodeEmoji: return cls(name=string) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class CustomEmoji(snowflakes.Unique, Emoji): """Represents a custom emoji. @@ -330,7 +328,7 @@ def parse(cls, string: str, /) -> CustomEmoji: raise ValueError("Expected an emoji ID or emoji mention") -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class KnownCustomEmoji(CustomEmoji): """Represents an emoji that is known from a guild the bot is in. @@ -338,9 +336,7 @@ class KnownCustomEmoji(CustomEmoji): _are_ part of. As a result, it contains a lot more information with it. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" guild_id: snowflakes.Snowflake = attr.field(eq=False, hash=False, repr=False) diff --git a/hikari/errors.py b/hikari/errors.py index a942061065..caa1898d87 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -57,7 +57,6 @@ import attr -from hikari.internal import attr_extensions from hikari.internal import enums if typing.TYPE_CHECKING: @@ -68,8 +67,7 @@ from hikari.internal import routes -@attr_extensions.with_copy -@attr.define(auto_exc=True, repr=False, init=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, init=False, weakref_slot=False) class HikariError(RuntimeError): """Base for an error raised by this API. @@ -80,8 +78,7 @@ class HikariError(RuntimeError): """ -@attr_extensions.with_copy -@attr.define(auto_exc=True, repr=False, init=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, init=False, weakref_slot=False) class HikariWarning(RuntimeWarning): """Base for a warning raised by this API. @@ -92,7 +89,7 @@ class HikariWarning(RuntimeWarning): """ -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class HikariInterrupt(KeyboardInterrupt, HikariError): """Exception raised when a kill signal is handled internally.""" @@ -103,7 +100,7 @@ class HikariInterrupt(KeyboardInterrupt, HikariError): """The signal name that was raised.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class ComponentNotRunningError(HikariError): """An exception thrown if trying to interact with a component that is not running.""" @@ -114,7 +111,7 @@ def __str__(self) -> str: return self.reason -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class UnrecognisedEntityError(HikariError): """An exception thrown when an unrecognised entity is found.""" @@ -125,7 +122,7 @@ def __str__(self) -> str: return self.reason -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class GatewayError(HikariError): """A base exception type for anything that can be thrown by the Gateway.""" @@ -171,7 +168,7 @@ def is_standard(self) -> bool: return bool((self.value // 1000) == 1) -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class GatewayConnectionError(GatewayError): """An exception thrown if a connection issue occurs.""" @@ -179,7 +176,7 @@ def __str__(self) -> str: return f"Failed to connect to server: {self.reason!r}" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class GatewayServerClosedConnectionError(GatewayError): """An exception raised when the server closes the connection.""" @@ -214,7 +211,7 @@ def __str__(self) -> str: return f"Server closed connection with code {self.code} ({self.reason})" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class HTTPError(HikariError): """Base exception raised if an HTTP error occurs while making a request.""" @@ -222,7 +219,7 @@ class HTTPError(HikariError): """The error message.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class HTTPClientClosedError(HTTPError): """Exception raised if an `aiohttp.ClientSession` was closed. @@ -407,7 +404,7 @@ class RESTErrorCode(int, enums.Enum): """API resource is currently overloaded. Try again a little later.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class HTTPResponseError(HTTPError): """Base exception for an erroneous HTTP response.""" @@ -449,7 +446,7 @@ def __str__(self) -> str: return f"{name_value}: '{body[:200]}{'...' if chomped else ''}' for {self.url}" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class ClientHTTPResponseError(HTTPResponseError): """Base exception for an erroneous HTTP response that is a client error. @@ -458,7 +455,7 @@ class ClientHTTPResponseError(HTTPResponseError): """ -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class BadRequestError(ClientHTTPResponseError): """Raised when you send an invalid request somehow.""" @@ -466,7 +463,7 @@ class BadRequestError(ClientHTTPResponseError): """The HTTP status code for the response.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class UnauthorizedError(ClientHTTPResponseError): """Raised when you are not authorized to access a specific resource.""" @@ -474,7 +471,7 @@ class UnauthorizedError(ClientHTTPResponseError): """The HTTP status code for the response.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class ForbiddenError(ClientHTTPResponseError): """Raised when you are not allowed to access a specific resource. @@ -487,7 +484,7 @@ class ForbiddenError(ClientHTTPResponseError): """The HTTP status code for the response.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class NotFoundError(ClientHTTPResponseError): """Raised when something is not found.""" @@ -495,7 +492,7 @@ class NotFoundError(ClientHTTPResponseError): """The HTTP status code for the response.""" -@attr.define(auto_exc=True, kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, kw_only=True, repr=False, weakref_slot=False) class RateLimitedError(ClientHTTPResponseError): """Raised when a non-global rate limit that cannot be handled occurs. @@ -529,7 +526,7 @@ def _(self) -> str: return f"You are being rate-limited for {self.retry_after:,} seconds on route {self.route}. Please slow down!" -@attr.define(auto_exc=True, kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, kw_only=True, repr=False, weakref_slot=False) class RateLimitTooLongError(HTTPError): """Internal error raised if the wait for a rate limit is too long. @@ -590,7 +587,7 @@ def __str__(self) -> str: return self.message -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class InternalServerError(HTTPResponseError): """Base exception for an erroneous HTTP response that is a server error. @@ -599,7 +596,7 @@ class InternalServerError(HTTPResponseError): """ -@attr.define(auto_exc=True, repr=False, init=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, init=False, weakref_slot=False) class MissingIntentWarning(HikariWarning): """Warning raised when subscribing to an event that cannot be fired. @@ -607,7 +604,7 @@ class MissingIntentWarning(HikariWarning): """ -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class BulkDeleteError(HikariError): """Exception raised when a bulk delete fails midway through a call. @@ -640,12 +637,12 @@ def __str__(self) -> str: return f"Error encountered when bulk deleting messages ({deleted}/{total} messages deleted)" -@attr.define(auto_exc=True, repr=False, init=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, init=False, weakref_slot=False) class VoiceError(HikariError): """Error raised when a problem occurs with the voice subsystem.""" -@attr.define(auto_exc=True, repr=False, weakref_slot=False) +@attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class MissingIntentError(HikariError, ValueError): """Error raised when you try to perform an action without an intent. diff --git a/hikari/events/base_events.py b/hikari/events/base_events.py index 9d562d41ae..50da80763b 100644 --- a/hikari/events/base_events.py +++ b/hikari/events/base_events.py @@ -41,7 +41,6 @@ from hikari import intents from hikari import traits from hikari.api import shard as gateway_shard -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: import types @@ -52,10 +51,11 @@ NO_RECURSIVE_THROW_ATTR: typing.Final[str] = "___norecursivethrow___" -@attr.define(kw_only=True, weakref_slot=False) class Event(abc.ABC): """Base event type that all Hikari events should subclass.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def app(self) -> traits.RESTAware: @@ -154,8 +154,7 @@ def is_no_recursive_throw_event(obj: typing.Union[T, typing.Type[T]]) -> bool: @no_recursive_throw() -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ExceptionEvent(Event, typing.Generic[FailedEventT]): """Event that is raised when another event handler raises an `Exception`. diff --git a/hikari/events/channel_events.py b/hikari/events/channel_events.py index 5b13ba6599..6cf0cca668 100644 --- a/hikari/events/channel_events.py +++ b/hikari/events/channel_events.py @@ -55,7 +55,6 @@ from hikari import traits from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: import datetime @@ -69,10 +68,11 @@ @base_events.requires_intents(intents.Intents.GUILDS, intents.Intents.DM_MESSAGES) -@attr.define(kw_only=True, weakref_slot=False) class ChannelEvent(shard_events.ShardEvent, abc.ABC): """Event base for any channel-bound event in guilds or private messages.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel_id(self) -> snowflakes.Snowflake: @@ -124,10 +124,11 @@ async def fetch_channel(self) -> channels.PartialChannel: @base_events.requires_intents(intents.Intents.GUILDS) -@attr.define(kw_only=True, weakref_slot=False) class GuildChannelEvent(ChannelEvent, abc.ABC): """Event base for any channel-bound event in guilds.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -246,10 +247,11 @@ async def fetch_channel(self) -> channels.GuildChannel: return channel -@attr.define(kw_only=True, weakref_slot=False) class DMChannelEvent(ChannelEvent, abc.ABC): """Event base for any channel-bound event in private messages.""" + __slots__: typing.Sequence[str] = () + async def fetch_channel(self) -> channels.PrivateChannel: """Perform an API call to fetch the details about this channel. @@ -292,10 +294,11 @@ async def fetch_channel(self) -> channels.PrivateChannel: @base_events.requires_intents(intents.Intents.GUILDS, intents.Intents.DM_MESSAGES) -@attr.define(kw_only=True, weakref_slot=False) class ChannelCreateEvent(ChannelEvent, abc.ABC): """Base event for any channel being created.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel(self) -> channels.PartialChannel: @@ -314,15 +317,14 @@ def channel_id(self) -> snowflakes.Snowflake: @base_events.requires_intents(intents.Intents.GUILDS) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class GuildChannelCreateEvent(GuildChannelEvent, ChannelCreateEvent): """Event fired when a guild channel is created.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel: channels.GuildChannel = attr.field(repr=True) @@ -341,10 +343,11 @@ def guild_id(self) -> snowflakes.Snowflake: @base_events.requires_intents(intents.Intents.GUILDS, intents.Intents.DM_MESSAGES) -@attr.define(kw_only=True, weakref_slot=False) class ChannelUpdateEvent(ChannelEvent, abc.ABC): """Base event for any channel being updated.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel(self) -> channels.PartialChannel: @@ -363,15 +366,14 @@ def channel_id(self) -> snowflakes.Snowflake: @base_events.requires_intents(intents.Intents.GUILDS) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class GuildChannelUpdateEvent(GuildChannelEvent, ChannelUpdateEvent): """Event fired when a guild channel is edited.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_channel: typing.Optional[channels.GuildChannel] = attr.field(repr=True) @@ -396,10 +398,11 @@ def guild_id(self) -> snowflakes.Snowflake: @base_events.requires_intents(intents.Intents.GUILDS, intents.Intents.DM_MESSAGES) -@attr.define(kw_only=True, weakref_slot=False) class ChannelDeleteEvent(ChannelEvent, abc.ABC): """Base event for any channel being deleted.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel(self) -> channels.PartialChannel: @@ -423,15 +426,14 @@ async def fetch_channel(self) -> typing.NoReturn: @base_events.requires_intents(intents.Intents.GUILDS) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class GuildChannelDeleteEvent(GuildChannelEvent, ChannelDeleteEvent): """Event fired when a guild channel is deleted.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel: channels.GuildChannel = attr.field(repr=True) @@ -455,10 +457,11 @@ async def fetch_channel(self) -> typing.NoReturn: @base_events.requires_intents(intents.Intents.DM_MESSAGES, intents.Intents.GUILDS) -@attr.define(kw_only=True, weakref_slot=False) class PinsUpdateEvent(ChannelEvent, abc.ABC): """Base event fired when a message is pinned/unpinned in a channel.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def last_pin_timestamp(self) -> typing.Optional[datetime.datetime]: @@ -498,15 +501,14 @@ async def fetch_pins(self) -> typing.Sequence[messages.Message]: @base_events.requires_intents(intents.Intents.GUILDS) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class GuildPinsUpdateEvent(PinsUpdateEvent, GuildChannelEvent): """Event fired when a message is pinned/unpinned in a guild channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() @@ -576,15 +578,14 @@ async def fetch_channel(self) -> channels.GuildTextChannel: @base_events.requires_intents(intents.Intents.DM_MESSAGES) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class DMPinsUpdateEvent(PinsUpdateEvent, DMChannelEvent): """Event fired when a message is pinned/unpinned in a private channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() @@ -628,10 +629,11 @@ async def fetch_channel(self) -> channels.DMChannel: @base_events.requires_intents(intents.Intents.GUILD_INVITES) -@attr.define(kw_only=True, weakref_slot=False) class InviteEvent(GuildChannelEvent, abc.ABC): """Base event type for guild invite updates.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def code(self) -> str: @@ -675,15 +677,14 @@ async def fetch_invite(self) -> invites.Invite: @base_events.requires_intents(intents.Intents.GUILD_INVITES) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class InviteCreateEvent(InviteEvent): """Event fired when an invite is created in a channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. invite: invites.InviteWithMetadata = attr.field() @@ -714,15 +715,14 @@ def code(self) -> str: @base_events.requires_intents(intents.Intents.GUILD_INVITES) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class InviteDeleteEvent(InviteEvent): """Event fired when an invite is deleted from a channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() @@ -747,8 +747,7 @@ async def fetch_invite(self) -> typing.NoReturn: @base_events.requires_intents(intents.Intents.GUILD_WEBHOOKS) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class WebhookUpdateEvent(GuildChannelEvent): """Event fired when a webhook is created/updated/deleted in a channel. @@ -758,10 +757,10 @@ class WebhookUpdateEvent(GuildChannelEvent): the channel manually beforehand. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() diff --git a/hikari/events/guild_events.py b/hikari/events/guild_events.py index 4576873c55..1eff645ec0 100644 --- a/hikari/events/guild_events.py +++ b/hikari/events/guild_events.py @@ -50,7 +50,6 @@ from hikari import intents from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import channels as channels_ @@ -64,13 +63,14 @@ from hikari.api import shard as gateway_shard -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents( intents.Intents.GUILDS, intents.Intents.GUILD_BANS, intents.Intents.GUILD_EMOJIS, intents.Intents.GUILD_PRESENCES ) class GuildEvent(shard_events.ShardEvent, abc.ABC): """Event base for any guild-bound event.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -119,7 +119,6 @@ async def fetch_guild_preview(self) -> guilds.GuildPreview: return await self.app.rest.fetch_guild_preview(self.guild_id) -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class GuildVisibilityEvent(GuildEvent, abc.ABC): """Event base for any event that changes the visibility of a guild. @@ -130,9 +129,10 @@ class GuildVisibilityEvent(GuildEvent, abc.ABC): the user joins a new guild. """ + __slots__: typing.Sequence[str] = () + -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class GuildAvailableEvent(GuildVisibilityEvent): """Event fired when a guild becomes available. @@ -145,10 +145,10 @@ class GuildAvailableEvent(GuildVisibilityEvent): event models. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild: guilds.GatewayGuild = attr.field() @@ -234,8 +234,7 @@ def guild_id(self) -> snowflakes.Snowflake: return self.guild.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class GuildLeaveEvent(GuildVisibilityEvent): """Event fired when the bot is banned/kicked/leaves a guild. @@ -243,10 +242,10 @@ class GuildLeaveEvent(GuildVisibilityEvent): This will also fire if the guild was deleted. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -258,32 +257,30 @@ async def fetch_guild(self) -> typing.NoReturn: ... -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class GuildUnavailableEvent(GuildVisibilityEvent): """Event fired when a guild becomes unavailable because of an outage.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class GuildUpdateEvent(GuildEvent): """Event fired when an existing guild is updated.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_guild: typing.Optional[guilds.Guild] = attr.field() @@ -325,11 +322,12 @@ def guild_id(self) -> snowflakes.Snowflake: return self.guild.id -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_BANS) class BanEvent(GuildEvent, abc.ABC): """Event base for any guild ban or unban.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def user(self) -> users.User: @@ -351,16 +349,15 @@ async def fetch_user(self) -> users.User: """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_BANS) class BanCreateEvent(BanEvent): """Event that is fired when a user is banned from a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -383,16 +380,15 @@ async def fetch_ban(self) -> guilds.GuildMemberBan: return await self.app.rest.fetch_ban(self.guild_id, self.user) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_BANS) class BanDeleteEvent(BanEvent): """Event that is fired when a user is unbanned from a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -402,16 +398,15 @@ class BanDeleteEvent(BanEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_EMOJIS) class EmojisUpdateEvent(GuildEvent): """Event that is fired when the emojis in a guild are updated.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -443,11 +438,12 @@ async def fetch_emojis(self) -> typing.Sequence[emojis_.KnownCustomEmoji]: return await self.app.rest.fetch_guild_emojis(self.guild_id) -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_INTEGRATIONS) class IntegrationEvent(GuildEvent, abc.ABC): """Event base for any integration related events.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def application_id(self) -> typing.Optional[snowflakes.Snowflake]: @@ -489,16 +485,15 @@ async def fetch_integrations(self) -> typing.Sequence[guilds.Integration]: return await self.app.rest.fetch_integrations(self.guild_id) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_INTEGRATIONS) class IntegrationCreateEvent(IntegrationEvent): """Event that is fired when an integration is created in a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. integration: guilds.Integration = attr.field() @@ -520,16 +515,15 @@ def id(self) -> snowflakes.Snowflake: return self.integration.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_INTEGRATIONS) class IntegrationDeleteEvent(IntegrationEvent): """Event that is fired when an integration is deleted in a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. application_id: typing.Optional[snowflakes.Snowflake] = attr.field() @@ -542,16 +536,15 @@ class IntegrationDeleteEvent(IntegrationEvent): # <> -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_INTEGRATIONS) class IntegrationUpdateEvent(IntegrationEvent): """Event that is fired when an integration is updated in a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. integration: guilds.Integration = attr.field() @@ -573,8 +566,7 @@ def id(self) -> snowflakes.Snowflake: return self.integration.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_PRESENCES) class PresenceUpdateEvent(shard_events.ShardEvent): """Event fired when a user in a guild updates their presence in a guild. @@ -590,10 +582,10 @@ class PresenceUpdateEvent(shard_events.ShardEvent): shards that saw the presence update. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_presence: typing.Optional[presences_.MemberPresence] = attr.field() diff --git a/hikari/events/lifetime_events.py b/hikari/events/lifetime_events.py index 4538b935dc..44e147a4af 100644 --- a/hikari/events/lifetime_events.py +++ b/hikari/events/lifetime_events.py @@ -35,14 +35,12 @@ import attr from hikari.events import base_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import traits -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class StartingEvent(base_events.Event): """Event that is triggered before the application connects to discord. @@ -60,12 +58,11 @@ class StartingEvent(base_events.Event): should consider using `StartedEvent` instead. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class StartedEvent(base_events.Event): """Event that is triggered after the application has started. @@ -77,12 +74,11 @@ class StartedEvent(base_events.Event): consider using `StartingEvent` instead. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class StoppingEvent(base_events.Event): """Event that is triggered as soon as the application is requested to close. @@ -102,12 +98,11 @@ class StoppingEvent(base_events.Event): should consider using `StoppedEvent` instead. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class StoppedEvent(base_events.Event): """Event that is triggered once the application has disconnected. @@ -126,5 +121,5 @@ class StoppedEvent(base_events.Event): `StoppingEvent` instead. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. diff --git a/hikari/events/member_events.py b/hikari/events/member_events.py index 97bbe52f07..eba6bcf907 100644 --- a/hikari/events/member_events.py +++ b/hikari/events/member_events.py @@ -39,7 +39,6 @@ from hikari import traits from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import guilds @@ -48,11 +47,12 @@ from hikari.api import shard as gateway_shard -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MEMBERS) class MemberEvent(shard_events.ShardEvent, abc.ABC): """Event base for any events that concern guild members.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -104,16 +104,15 @@ def guild(self) -> typing.Optional[guilds.GatewayGuild]: return self.app.cache.get_available_guild(self.guild_id) or self.app.cache.get_unavailable_guild(self.guild_id) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MEMBERS) class MemberCreateEvent(MemberEvent): """Event that is fired when a member joins a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. member: guilds.Member = attr.field() @@ -136,8 +135,7 @@ def user(self) -> users.User: return self.member.user -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MEMBERS) class MemberUpdateEvent(MemberEvent): """Event that is fired when a member is updated in a guild. @@ -145,10 +143,10 @@ class MemberUpdateEvent(MemberEvent): This may occur if roles are amended, or if the nickname is changed. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_member: typing.Optional[guilds.Member] = attr.field() @@ -177,16 +175,15 @@ def user(self) -> users.User: return self.member.user -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MEMBERS) class MemberDeleteEvent(MemberEvent): """Event fired when a member is kicked from or leaves a guild.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() diff --git a/hikari/events/message_events.py b/hikari/events/message_events.py index 930c2a09f3..4081580f94 100644 --- a/hikari/events/message_events.py +++ b/hikari/events/message_events.py @@ -48,7 +48,6 @@ from hikari import traits from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import embeds as embeds_ @@ -59,11 +58,12 @@ from hikari.api import shard as shard_ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES, intents.Intents.GUILD_MESSAGES) class MessageEvent(shard_events.ShardEvent, abc.ABC): """Any event that concerns manipulation of messages.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel_id(self) -> snowflakes.Snowflake: @@ -87,11 +87,12 @@ def message_id(self) -> snowflakes.Snowflake: """ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES, intents.Intents.GUILD_MESSAGES) class MessageCreateEvent(MessageEvent, abc.ABC): """Event that is fired when a message is created.""" + __slots__: typing.Sequence[str] = () + @property def author(self) -> users.User: """User that sent the message. @@ -201,8 +202,7 @@ def message_id(self) -> snowflakes.Snowflake: return self.message.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) class GuildMessageCreateEvent(MessageCreateEvent): """Event that is fired when a message is created within a guild. @@ -210,13 +210,13 @@ class GuildMessageCreateEvent(MessageCreateEvent): This contains the full message in the internal `message` attribute. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> message: messages.Message = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> @property @@ -285,8 +285,7 @@ def guild_id(self) -> snowflakes.Snowflake: return guild_id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES) class DMMessageCreateEvent(MessageCreateEvent): """Event that is fired when a message is created within a DM. @@ -294,17 +293,16 @@ class DMMessageCreateEvent(MessageCreateEvent): This contains the full message in the internal `message` attribute. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> message: messages.Message = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES, intents.Intents.GUILD_MESSAGES) class MessageUpdateEvent(MessageEvent, abc.ABC): """Event that is fired when a message is updated. @@ -314,6 +312,8 @@ class MessageUpdateEvent(MessageEvent, abc.ABC): due to Discord limitations. """ + __slots__: typing.Sequence[str] = () + @property def author(self) -> typing.Optional[users.User]: """User that sent the message. @@ -448,8 +448,7 @@ def message_id(self) -> snowflakes.Snowflake: return self.message.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) class GuildMessageUpdateEvent(MessageUpdateEvent): """Event that is fired when a message is updated in a guild. @@ -459,7 +458,7 @@ class GuildMessageUpdateEvent(MessageUpdateEvent): due to Discord limitations. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> old_message: typing.Optional[messages.PartialMessage] = attr.field() @@ -471,7 +470,7 @@ class GuildMessageUpdateEvent(MessageUpdateEvent): message: messages.PartialMessage = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> @property @@ -556,8 +555,7 @@ def guild_id(self) -> snowflakes.Snowflake: return guild_id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES) class DMMessageUpdateEvent(MessageUpdateEvent): """Event that is fired when a message is updated in a DM. @@ -567,7 +565,7 @@ class DMMessageUpdateEvent(MessageUpdateEvent): due to Discord limitations. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> old_message: typing.Optional[messages.PartialMessage] = attr.field() @@ -579,11 +577,10 @@ class DMMessageUpdateEvent(MessageUpdateEvent): message: messages.PartialMessage = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGES, intents.Intents.DM_MESSAGES) class MessageDeleteEvent(MessageEvent, abc.ABC): """Special event that is triggered when one or more messages get deleted. @@ -596,6 +593,8 @@ class MessageDeleteEvent(MessageEvent, abc.ABC): `is_bulk` attribute. """ + __slots__: typing.Sequence[str] = () + @property def message_id(self) -> snowflakes.Snowflake: """Get the ID of the first deleted message. @@ -638,8 +637,7 @@ def is_bulk(self) -> bool: """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGES) class GuildMessageDeleteEvent(MessageDeleteEvent): """Event that is triggered if messages are deleted in a guild. @@ -653,7 +651,7 @@ class GuildMessageDeleteEvent(MessageDeleteEvent): checking the `is_bulk` attribute. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> channel_id: snowflakes.Snowflake = attr.field() @@ -674,7 +672,7 @@ class GuildMessageDeleteEvent(MessageDeleteEvent): message_ids: typing.AbstractSet[snowflakes.Snowflake] = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> @property @@ -720,8 +718,7 @@ def guild(self) -> typing.Optional[guilds.GatewayGuild]: return self.app.cache.get_guild(self.guild_id) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGES) class DMMessageDeleteEvent(MessageDeleteEvent): """Event that is triggered if messages are deleted in a DM. @@ -737,7 +734,7 @@ class DMMessageDeleteEvent(MessageDeleteEvent): `is_bulk` attribute. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <> channel_id: snowflakes.Snowflake = attr.field() @@ -749,5 +746,5 @@ class DMMessageDeleteEvent(MessageDeleteEvent): message_ids: typing.AbstractSet[snowflakes.Snowflake] = attr.field() # <> - shard: shard_.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: shard_.GatewayShard = attr.field() # <> diff --git a/hikari/events/reaction_events.py b/hikari/events/reaction_events.py index 7796ed83c1..36b432ebb9 100644 --- a/hikari/events/reaction_events.py +++ b/hikari/events/reaction_events.py @@ -50,7 +50,6 @@ from hikari import intents from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import emojis @@ -60,11 +59,12 @@ from hikari.api import shard as gateway_shard -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS, intents.Intents.DM_MESSAGE_REACTIONS) class ReactionEvent(shard_events.ShardEvent, abc.ABC): """Event base for any message reaction event.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel_id(self) -> snowflakes.Snowflake: @@ -88,11 +88,12 @@ def message_id(self) -> snowflakes.Snowflake: """ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS) class GuildReactionEvent(ReactionEvent, abc.ABC): """Event base for any reaction-bound event in guild messages.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -105,17 +106,19 @@ def guild_id(self) -> snowflakes.Snowflake: """ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGE_REACTIONS) class DMReactionEvent(ReactionEvent, abc.ABC): """Event base for any reaction-bound event in private messages.""" + __slots__: typing.Sequence[str] = () + -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS, intents.Intents.DM_MESSAGE_REACTIONS) class ReactionAddEvent(ReactionEvent, abc.ABC): """Event base for any reaction that is added to a message.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def user_id(self) -> snowflakes.Snowflake: @@ -140,11 +143,12 @@ def emoji(self) -> emojis.Emoji: """ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS, intents.Intents.DM_MESSAGE_REACTIONS) class ReactionDeleteEvent(ReactionEvent, abc.ABC): """Event base for any single reaction that is removed from a message.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def user_id(self) -> snowflakes.Snowflake: @@ -170,17 +174,19 @@ def emoji(self) -> emojis.Emoji: """ -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS, intents.Intents.DM_MESSAGE_REACTIONS) class ReactionDeleteAllEvent(ReactionEvent, abc.ABC): """Event base fired when all reactions are removed from a message.""" + __slots__: typing.Sequence[str] = () + -@attr.define(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS, intents.Intents.DM_MESSAGE_REACTIONS) class ReactionDeleteEmojiEvent(ReactionEvent, abc.ABC): """Event base fired when all reactions are removed for one emoji.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def emoji(self) -> emojis.Emoji: @@ -195,16 +201,15 @@ def emoji(self) -> emojis.Emoji: """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS) class GuildReactionAddEvent(GuildReactionEvent, ReactionAddEvent): """Event fired when a reaction is added to a guild message.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. member: guilds.Member = attr.field() @@ -237,16 +242,15 @@ def user_id(self) -> snowflakes.Snowflake: return self.member.user.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS) class GuildReactionDeleteEvent(GuildReactionEvent, ReactionDeleteEvent): """Event fired when a reaction is removed from a guild message.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. user_id: snowflakes.Snowflake = attr.field() @@ -265,16 +269,15 @@ class GuildReactionDeleteEvent(GuildReactionEvent, ReactionDeleteEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS) class GuildReactionDeleteEmojiEvent(GuildReactionEvent, ReactionDeleteEmojiEvent): """Event fired when an emoji is removed from a guild message's reactions.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -290,16 +293,15 @@ class GuildReactionDeleteEmojiEvent(GuildReactionEvent, ReactionDeleteEmojiEvent # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_REACTIONS) class GuildReactionDeleteAllEvent(GuildReactionEvent, ReactionDeleteAllEvent): """Event fired when all of a guild message's reactions are removed.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() @@ -312,16 +314,15 @@ class GuildReactionDeleteAllEvent(GuildReactionEvent, ReactionDeleteAllEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGE_REACTIONS) class DMReactionAddEvent(DMReactionEvent, ReactionAddEvent): """Event fired when a reaction is added to a private message.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. user_id: snowflakes.Snowflake = attr.field() @@ -337,16 +338,15 @@ class DMReactionAddEvent(DMReactionEvent, ReactionAddEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGE_REACTIONS) class DMReactionDeleteEvent(DMReactionEvent, ReactionDeleteEvent): """Event fired when a reaction is removed from a private message.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. user_id: snowflakes.Snowflake = attr.field() @@ -362,16 +362,15 @@ class DMReactionDeleteEvent(DMReactionEvent, ReactionDeleteEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGE_REACTIONS) class DMReactionDeleteEmojiEvent(DMReactionEvent, ReactionDeleteEmojiEvent): """Event fired when an emoji is removed from a private message's reactions.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() @@ -384,16 +383,15 @@ class DMReactionDeleteEmojiEvent(DMReactionEvent, ReactionDeleteEmojiEvent): # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.DM_MESSAGE_REACTIONS) class DMReactionDeleteAllEvent(DMReactionEvent, ReactionDeleteAllEvent): """Event fired when all of a private message's reactions are removed.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() diff --git a/hikari/events/role_events.py b/hikari/events/role_events.py index c0888e7e4f..6e377368c0 100644 --- a/hikari/events/role_events.py +++ b/hikari/events/role_events.py @@ -38,7 +38,6 @@ from hikari import intents from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import guilds @@ -51,6 +50,8 @@ class RoleEvent(shard_events.ShardEvent, abc.ABC): """Event base for any event that involves guild roles.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -74,16 +75,15 @@ def role_id(self) -> snowflakes.Snowflake: """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class RoleCreateEvent(RoleEvent): """Event fired when a role is created.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. role: guilds.Role = attr.field() @@ -106,16 +106,15 @@ def role_id(self) -> snowflakes.Snowflake: return self.role.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class RoleUpdateEvent(RoleEvent): """Event fired when a role is updated.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_role: typing.Optional[guilds.Role] = attr.field() @@ -144,16 +143,15 @@ def role_id(self) -> snowflakes.Snowflake: return self.role.id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) @base_events.requires_intents(intents.Intents.GUILDS) class RoleDeleteEvent(RoleEvent): """Event fired when a role is deleted.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field() diff --git a/hikari/events/shard_events.py b/hikari/events/shard_events.py index a25e103e8c..0398a97939 100644 --- a/hikari/events/shard_events.py +++ b/hikari/events/shard_events.py @@ -41,7 +41,6 @@ import attr from hikari.events import base_events -from hikari.internal import attr_extensions from hikari.internal import collections if typing.TYPE_CHECKING: @@ -57,6 +56,8 @@ class ShardEvent(base_events.Event, abc.ABC): """Base class for any event that was shard-specific.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def shard(self) -> gateway_shard.GatewayShard: @@ -69,8 +70,7 @@ def shard(self) -> gateway_shard.GatewayShard: """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ShardPayload(ShardEvent): """Event fired for most shard events with their raw payload. @@ -79,10 +79,10 @@ class ShardPayload(ShardEvent): Discord and not artificial events like the `ShardStateEvent` events. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. name: str = attr.field() @@ -98,40 +98,39 @@ class ShardStateEvent(ShardEvent, abc.ABC): This currently wraps connection/disconnection/ready/resumed events only. """ + __slots__: typing.Sequence[str] = () + -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ShardConnectedEvent(ShardStateEvent): """Event fired when a shard connects.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ShardDisconnectedEvent(ShardStateEvent): """Event fired when a shard disconnects.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ShardReadyEvent(ShardStateEvent): """Event fired when a shard declares it is ready.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. actual_gateway_version: int = attr.field(repr=True) @@ -192,27 +191,25 @@ class ShardReadyEvent(ShardStateEvent): """ -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class ShardResumedEvent(ShardStateEvent): """Event fired when a shard resumes an existing session.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class MemberChunkEvent(ShardEvent, typing.Sequence["guilds.Member"]): """Event fired when a member chunk payload is received on a gateway shard.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field(repr=True) diff --git a/hikari/events/typing_events.py b/hikari/events/typing_events.py index bf238019b4..724aa879a9 100644 --- a/hikari/events/typing_events.py +++ b/hikari/events/typing_events.py @@ -40,7 +40,6 @@ from hikari.api import special_endpoints from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: import datetime @@ -55,6 +54,8 @@ class TypingEvent(shard_events.ShardEvent, abc.ABC): """Base event fired when a user begins typing in a channel.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def channel_id(self) -> snowflakes.Snowflake: @@ -132,15 +133,14 @@ def trigger_typing(self) -> special_endpoints.TypingIndicator: @base_events.requires_intents(intents.Intents.GUILD_MESSAGE_TYPING) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class GuildTypingEvent(TypingEvent): """Event fired when a user starts typing in a guild channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() @@ -255,15 +255,14 @@ async def fetch_user(self) -> guilds.Member: @base_events.requires_intents(intents.Intents.DM_MESSAGES) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class DMTypingEvent(TypingEvent): """Event fired when a user starts typing in a guild channel.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. channel_id: snowflakes.Snowflake = attr.field() diff --git a/hikari/events/user_events.py b/hikari/events/user_events.py index 22325109f8..60fbe8fac8 100644 --- a/hikari/events/user_events.py +++ b/hikari/events/user_events.py @@ -30,7 +30,6 @@ import attr from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import traits @@ -38,15 +37,14 @@ from hikari.api import shard as gateway_shard -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class OwnUserUpdateEvent(shard_events.ShardEvent): """Event fired when the account user is updated.""" - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_user: typing.Optional[users.OwnUser] = attr.field() diff --git a/hikari/events/voice_events.py b/hikari/events/voice_events.py index b0c9098790..bc5e593dcf 100644 --- a/hikari/events/voice_events.py +++ b/hikari/events/voice_events.py @@ -38,7 +38,6 @@ from hikari import intents from hikari.events import base_events from hikari.events import shard_events -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: from hikari import snowflakes @@ -47,10 +46,11 @@ from hikari.api import shard as gateway_shard -@attr.define(kw_only=True, weakref_slot=False) class VoiceEvent(shard_events.ShardEvent, abc.ABC): """Base for any voice-related event.""" + __slots__: typing.Sequence[str] = () + @property @abc.abstractmethod def guild_id(self) -> snowflakes.Snowflake: @@ -64,8 +64,7 @@ def guild_id(self) -> snowflakes.Snowflake: @base_events.requires_intents(intents.Intents.GUILD_VOICE_STATES) -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class VoiceStateUpdateEvent(VoiceEvent): """Event fired when a user changes their voice state. @@ -75,10 +74,10 @@ class VoiceStateUpdateEvent(VoiceEvent): to connect to the voice gateway to stream audio or video content. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. old_state: typing.Optional[voices.VoiceState] = attr.field(repr=True) @@ -102,8 +101,7 @@ def guild_id(self) -> snowflakes.Snowflake: return self.state.guild_id -@attr_extensions.with_copy -@attr.define(kw_only=True, weakref_slot=False) +@attr.frozen(kw_only=True, weakref_slot=False) class VoiceServerUpdateEvent(VoiceEvent): """Event fired when a voice server is changed. @@ -111,10 +109,10 @@ class VoiceServerUpdateEvent(VoiceEvent): falls over to a new server. """ - app: traits.RESTAware = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + app: traits.RESTAware = attr.field() # <>. - shard: gateway_shard.GatewayShard = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + shard: gateway_shard.GatewayShard = attr.field() # <>. guild_id: snowflakes.Snowflake = attr.field(repr=True) diff --git a/hikari/guilds.py b/hikari/guilds.py index 44506cf3a4..d3cf7e04d0 100644 --- a/hikari/guilds.py +++ b/hikari/guilds.py @@ -64,7 +64,6 @@ from hikari import undefined from hikari import urls from hikari import users -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -246,14 +245,11 @@ class GuildVerificationLevel(int, enums.Enum): """Must have a verified phone number.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class GuildWidget: """Represents a guild widget.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" channel_id: typing.Optional[snowflakes.Snowflake] = attr.field(repr=True) @@ -263,8 +259,7 @@ class GuildWidget: """Whether this embed is enabled.""" -@attr_extensions.with_copy -@attr.define(eq=False, hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(eq=False, hash=False, kw_only=True, weakref_slot=False) class Member(users.User): """Used to represent a guild bound member.""" @@ -365,10 +360,6 @@ def flags(self) -> users.UserFlag: def id(self) -> snowflakes.Snowflake: return self.user.id - @id.setter - def id(self, value: snowflakes.Snowflake) -> None: - raise TypeError("Cannot mutate the ID of a member") - @property def is_bot(self) -> bool: return self.user.is_bot @@ -550,14 +541,11 @@ def __eq__(self, other: object) -> bool: return self.user == other -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialRole(snowflakes.Unique): """Represents a partial guild bound Role object.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -570,7 +558,7 @@ def __str__(self) -> str: return self.name -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Role(PartialRole): """Represents a guild bound Role object.""" @@ -650,8 +638,7 @@ class IntegrationExpireBehaviour(int, enums.Enum): """Kick the subscriber.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class IntegrationAccount: """An account that's linked to an integration.""" @@ -666,8 +653,7 @@ def __str__(self) -> str: # This is here rather than in applications.py to avoid circular imports -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialApplication(snowflakes.Unique): """A partial representation of a Discord application.""" @@ -735,8 +721,7 @@ def make_icon_url(self, *, ext: str = "png", size: int = 4096) -> typing.Optiona ) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class IntegrationApplication(PartialApplication): """An application that's linked to an integration.""" @@ -744,8 +729,7 @@ class IntegrationApplication(PartialApplication): """The bot associated with this application.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialIntegration(snowflakes.Unique): """A partial representation of an integration, found in audit logs.""" @@ -765,7 +749,7 @@ def __str__(self) -> str: return self.name -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Integration(PartialIntegration): """Represents a guild integration object.""" @@ -822,8 +806,7 @@ class Integration(PartialIntegration): """ -@attr_extensions.with_copy -@attr.define(hash=False, weakref_slot=False) +@attr.frozen(hash=False, weakref_slot=False) class WelcomeChannel: """Used to represent channels on guild welcome screens.""" @@ -837,8 +820,7 @@ class WelcomeChannel: """The emoji shown in the welcome screen channel if set else `builtins.None`.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class WelcomeScreen: """Used to represent guild welcome screens on Discord.""" @@ -849,8 +831,7 @@ class WelcomeScreen: """An array of up to 5 of the channels shown in the welcome screen.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class GuildMemberBan: """Used to represent guild bans.""" @@ -861,14 +842,11 @@ class GuildMemberBan: """The object of the user this ban targets.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialGuild(snowflakes.Unique): """Base object for any partial guild objects.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -946,7 +924,7 @@ def make_icon_url(self, *, ext: typing.Optional[str] = None, size: int = 4096) - ) -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GuildPreview(PartialGuild): """A preview of a guild with the `GuildFeature.DISCOVERABLE` feature.""" @@ -1050,7 +1028,7 @@ def make_splash_url(self, *, ext: str = "png", size: int = 4096) -> typing.Optio ) -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Guild(PartialGuild, abc.ABC): """A representation of a guild on Discord.""" @@ -1339,7 +1317,7 @@ def get_role(self, role: snowflakes.SnowflakeishOr[PartialRole]) -> typing.Optio """Get a role from the cache by it's ID.""" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class RESTGuild(Guild): """Guild specialization that is sent via the REST API only.""" @@ -1392,7 +1370,7 @@ def get_role(self, role: snowflakes.SnowflakeishOr[PartialRole]) -> typing.Optio return self._roles.get(snowflakes.Snowflake(role)) -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class GatewayGuild(Guild): """Guild specialization that is sent via the gateway only.""" diff --git a/hikari/impl/cache.py b/hikari/impl/cache.py index 4fe5ad8759..57da878958 100644 --- a/hikari/impl/cache.py +++ b/hikari/impl/cache.py @@ -26,7 +26,6 @@ __all__: typing.List[str] = ["CacheImpl"] -import copy import logging import typing @@ -317,7 +316,7 @@ def _get_guild( if not guild_record or not guild_record.guild or guild_record.is_available is not availability: return None - return copy.copy(guild_record.guild) + return guild_record.guild def get_guild( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], / @@ -326,7 +325,7 @@ def get_guild( return None guild_record = self._guild_entries.get(snowflakes.Snowflake(guild)) - return copy.copy(guild_record.guild) if guild_record and guild_record.guild else None + return guild_record.guild if guild_record and guild_record.guild else None def get_available_guild( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], / @@ -376,7 +375,7 @@ def set_guild(self, guild: guilds.GatewayGuild, /) -> None: return None guild_record = self._get_or_create_guild_record(guild.id) - guild_record.guild = copy.copy(guild) + guild_record.guild = guild guild_record.is_available = True def set_guild_availability( @@ -395,15 +394,8 @@ def update_guild( if not self._is_cache_enabled_for(config.CacheComponents.GUILDS): return None, None - guild = copy.copy(guild) cached_guild = self.get_guild(guild.id) - # We have to manually update these because Inconsistency is Discord's middle name. - if cached_guild: - guild.member_count = cached_guild.member_count - guild.joined_at = cached_guild.joined_at - guild.is_large = cached_guild.is_large - self.set_guild(guild) return cached_guild, self.get_guild(guild.id) @@ -464,13 +456,10 @@ def get_guild_channel( if not self._is_cache_enabled_for(config.CacheComponents.GUILD_CHANNELS): return None - channel = self._guild_channel_entries.get(snowflakes.Snowflake(channel)) - return cache_utility.copy_guild_channel(channel) if channel else None + return self._guild_channel_entries.get(snowflakes.Snowflake(channel)) def get_guild_channels_view(self) -> cache.CacheView[snowflakes.Snowflake, channels.GuildChannel]: - return cache_utility.CacheMappingView( - self._guild_channel_entries.freeze(), builder=cache_utility.copy_guild_channel # type: ignore[type-var] - ) + return cache_utility.CacheMappingView(self._guild_channel_entries.freeze()) def get_guild_channels_view_for_guild( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], / @@ -497,15 +486,13 @@ def sorter(args: typing.Tuple[snowflakes.Snowflake, channels.GuildChannel]) -> t return parent_position, 1, channel.position cached_channels = dict(sorted(cached_channels.items(), key=sorter)) - return cache_utility.CacheMappingView( - cached_channels, builder=cache_utility.copy_guild_channel # type: ignore[type-var] - ) + return cache_utility.CacheMappingView(cached_channels) def set_guild_channel(self, channel: channels.GuildChannel, /) -> None: if not self._is_cache_enabled_for(config.CacheComponents.GUILD_CHANNELS): return None - self._guild_channel_entries[channel.id] = cache_utility.copy_guild_channel(channel) + self._guild_channel_entries[channel.id] = channel guild_record = self._get_or_create_guild_record(channel.guild_id) if guild_record.channels is None: @@ -722,10 +709,10 @@ def delete_me(self) -> typing.Optional[users.OwnUser]: return cached_user def get_me(self) -> typing.Optional[users.OwnUser]: - return copy.copy(self._me) + return self._me def set_me(self, user: users.OwnUser, /) -> None: - self._me = copy.copy(user) + self._me = user def update_me( self, user: users.OwnUser, / @@ -860,7 +847,7 @@ def get_members_view( for guild_id, view in self._guild_entries.items() if view.members } - return cache_utility.Cache3DMappingView(views) + return cache_utility.CacheMappingView(views) def get_members_view_for_guild( self, guild_id: snowflakes.Snowflakeish, / @@ -1031,7 +1018,7 @@ def get_presences_view( for guild_id, guild_record in self._guild_entries.items() if guild_record.presences } - return cache_utility.Cache3DMappingView(views) + return cache_utility.CacheMappingView(views) def get_presences_view_for_guild( self, guild: snowflakes.SnowflakeishOr[guilds.PartialGuild], / @@ -1056,11 +1043,11 @@ def set_presence(self, presence: presences.MemberPresence, /) -> None: continue if emoji.id in self._unknown_custom_emoji_entries: - self._unknown_custom_emoji_entries[emoji.id].object = copy.copy(emoji) + self._unknown_custom_emoji_entries[emoji.id].object = emoji emoji_data = self._unknown_custom_emoji_entries[emoji.id] else: - emoji_data = cache_utility.RefCell(copy.copy(emoji)) + emoji_data = cache_utility.RefCell(emoji) self._unknown_custom_emoji_entries[emoji.id] = emoji_data self._increment_ref_count(emoji_data) @@ -1137,8 +1124,7 @@ def get_role(self, role: snowflakes.SnowflakeishOr[guilds.PartialRole], /) -> ty if not self._is_cache_enabled_for(config.CacheComponents.ROLES): return None - role = self._role_entries.get(snowflakes.Snowflake(role)) - return copy.copy(role) if role else None + return self._role_entries.get(snowflakes.Snowflake(role)) def get_roles_view(self) -> cache.CacheView[snowflakes.Snowflake, guilds.Role]: if not self._is_cache_enabled_for(config.CacheComponents.ROLES): @@ -1195,7 +1181,7 @@ def _garbage_collect_user( def get_user(self, user: snowflakes.SnowflakeishOr[users.PartialUser], /) -> typing.Optional[users.User]: user = self._user_entries.get(snowflakes.Snowflake(user)) - return user.copy() if user else None + return user.object if user else None def get_users_view(self) -> cache.CacheView[snowflakes.Snowflake, users.User]: if not self._user_entries: @@ -1209,10 +1195,10 @@ def get_users_view(self) -> cache.CacheView[snowflakes.Snowflake, users.User]: def _set_user(self, user: users.User, /) -> cache_utility.RefCell[users.User]: try: - self._user_entries[user.id].object = copy.copy(user) + self._user_entries[user.id].object = user cell = self._user_entries[user.id] except KeyError: - cell = cache_utility.RefCell(copy.copy(user)) + cell = cache_utility.RefCell(user) self._user_entries[user.id] = cell return cell @@ -1337,7 +1323,7 @@ def get_voice_states_view( for guild_id, guild_record in self._guild_entries.items() if guild_record.voice_states } - return cache_utility.Cache3DMappingView(views) + return cache_utility.CacheMappingView(views) def get_voice_states_view_for_channel( self, diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index d1828b0da1..fe558363de 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -54,7 +54,6 @@ from hikari import voices as voice_models from hikari import webhooks as webhook_models from hikari.api import entity_factory -from hikari.internal import attr_extensions from hikari.internal import data_binding from hikari.internal import time @@ -84,8 +83,7 @@ def _deserialize_max_age(seconds: int) -> typing.Optional[datetime.timedelta]: return datetime.timedelta(seconds=seconds) if seconds > 0 else None -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=False, weakref_slot=False) class _GuildChannelFields: id: snowflakes.Snowflake = attr.field() name: typing.Optional[str] = attr.field() @@ -97,8 +95,7 @@ class _GuildChannelFields: parent_id: typing.Optional[snowflakes.Snowflake] = attr.field() -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=False, weakref_slot=False) class _IntegrationFields: id: snowflakes.Snowflake = attr.field() name: str = attr.field() @@ -106,13 +103,12 @@ class _IntegrationFields: account: guild_models.IntegrationAccount = attr.field() -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=False, weakref_slot=False) class _GuildFields: id: snowflakes.Snowflake = attr.field() name: str = attr.field() icon_hash: str = attr.field() - features: typing.Sequence[guild_models.GuildFeatureish] = attr.field() + features: typing.Tuple[guild_models.GuildFeatureish, ...] = attr.field() splash_hash: typing.Optional[str] = attr.field() discovery_splash_hash: typing.Optional[str] = attr.field() owner_id: snowflakes.Snowflake = attr.field() @@ -139,8 +135,7 @@ class _GuildFields: is_nsfw: bool = attr.field() -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=False, weakref_slot=False) class _InviteFields: code: str = attr.field() guild: typing.Optional[invite_models.InviteGuild] = attr.field() @@ -154,8 +149,7 @@ class _InviteFields: approximate_member_count: typing.Optional[int] = attr.field() -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=False, weakref_slot=False) class _UserFields: id: snowflakes.Snowflake = attr.field() discriminator: str = attr.field() @@ -249,9 +243,9 @@ def __init__(self, app: traits.RESTAware) -> None: def deserialize_own_connection(self, payload: data_binding.JSONObject) -> application_models.OwnConnection: if (integration_payloads := payload.get("integrations")) is not None: - integrations = [self.deserialize_partial_integration(integration) for integration in integration_payloads] + integrations = tuple(self.deserialize_partial_integration(i) for i in integration_payloads) else: - integrations = [] + integrations = () return application_models.OwnConnection( id=payload["id"], @@ -271,7 +265,7 @@ def deserialize_own_guild(self, payload: data_binding.JSONObject) -> application id=snowflakes.Snowflake(payload["id"]), name=payload["name"], icon_hash=payload["icon"], - features=[guild_models.GuildFeature(feature) for feature in payload["features"]], + features=tuple(guild_models.GuildFeature(feature) for feature in payload["features"]), is_owner=bool(payload["owner"]), my_permissions=permission_models.Permissions(int(payload["permissions"])), ) @@ -283,7 +277,7 @@ def deserialize_application(self, payload: data_binding.JSONObject) -> applicati for member_payload in team_payload["members"]: team_member = application_models.TeamMember( membership_state=application_models.TeamMembershipState(member_payload["membership_state"]), - permissions=member_payload["permissions"], + permissions=tuple(member_payload["permissions"]), team_id=snowflakes.Snowflake(member_payload["team_id"]), user=self.deserialize_user(member_payload["user"]), ) @@ -308,7 +302,7 @@ def deserialize_application(self, payload: data_binding.JSONObject) -> applicati is_bot_public=payload.get("bot_public"), is_bot_code_grant_required=payload.get("bot_require_code_grant"), owner=self.deserialize_user(payload["owner"]), - rpc_origins=payload["rpc_origins"] if "rpc_origins" in payload else None, + rpc_origins=tuple(payload["rpc_origins"]) if "rpc_origins" in payload else None, summary=payload["summary"] or None, public_key=bytes.fromhex(payload["verify_key"]) if "verify_key" in payload else None, icon_hash=payload.get("icon"), @@ -341,7 +335,7 @@ def deserialize_authorization_information( return application_models.AuthorizationInformation( application=application, - scopes=[application_models.OAuth2Scope(scope) for scope in payload["scopes"]], + scopes=tuple(application_models.OAuth2Scope(scope) for scope in payload["scopes"]), expires_at=time.iso8601_datetime_string_to_datetime(payload["expires"]), user=self.deserialize_user(payload["user"]) if "user" in payload else None, ) @@ -351,7 +345,7 @@ def deserialize_partial_token(self, payload: data_binding.JSONObject) -> applica access_token=payload["access_token"], token_type=application_models.TokenType(payload["token_type"]), expires_in=datetime.timedelta(seconds=int(payload["expires_in"])), - scopes=[application_models.OAuth2Scope(scope) for scope in payload["scope"].split(" ")], + scopes=tuple(application_models.OAuth2Scope(scope) for scope in payload["scope"].split(" ")), ) def deserialize_authorization_token( @@ -361,7 +355,7 @@ def deserialize_authorization_token( access_token=payload["access_token"], token_type=application_models.TokenType(payload["token_type"]), expires_in=datetime.timedelta(seconds=int(payload["expires_in"])), - scopes=[application_models.OAuth2Scope(scope) for scope in payload["scope"].split(" ")], + scopes=tuple(application_models.OAuth2Scope(scope) for scope in payload["scope"].split(" ")), refresh_token=payload["refresh_token"], webhook=self.deserialize_webhook(payload["webhook"]) if "webhook" in payload else None, guild=self.deserialize_rest_guild(payload["guild"]) if "guild" in payload else None, @@ -372,7 +366,7 @@ def deserialize_implicit_token(self, query: data_binding.Query) -> application_m access_token=query["access_token"], token_type=application_models.TokenType(query["token_type"]), expires_in=datetime.timedelta(seconds=int(query["expires_in"])), - scopes=[application_models.OAuth2Scope(scope) for scope in query["scope"].split(" ")], + scopes=tuple(application_models.OAuth2Scope(scope) for scope in query["scope"].split(" ")), state=query.get("state"), ) @@ -452,28 +446,30 @@ def _deserialize_member_move_entry_info( app=self._app, channel_id=snowflakes.Snowflake(payload["channel_id"]), count=int(payload["count"]) ) + def _deserialize_change(self, payload: data_binding.JSONObject) -> audit_log_models.AuditLogChange: + key: typing.Union[audit_log_models.AuditLogChangeKey, str] = audit_log_models.AuditLogChangeKey(payload["key"]) + + new_value = payload.get("new_value") + old_value = payload.get("old_value") + if value_converter := self._audit_log_entry_converters.get(key): + new_value = value_converter(new_value) if new_value is not None else None + old_value = value_converter(old_value) if old_value is not None else None + + elif not isinstance(key, audit_log_models.AuditLogChangeKey): + _LOGGER.debug("Unknown audit log change key found %r", key) + + return audit_log_models.AuditLogChange(key=key, new_value=new_value, old_value=old_value) + def deserialize_audit_log(self, payload: data_binding.JSONObject) -> audit_log_models.AuditLog: entries = {} for entry_payload in payload["audit_log_entries"]: entry_id = snowflakes.Snowflake(entry_payload["id"]) - changes = [] - if (change_payloads := entry_payload.get("changes")) is not None: - for change_payload in change_payloads: - key: typing.Union[audit_log_models.AuditLogChangeKey, str] = audit_log_models.AuditLogChangeKey( - change_payload["key"] - ) + if change_payloads := entry_payload.get("changes"): + changes = tuple(self._deserialize_change(change_payload) for change_payload in change_payloads) - new_value = change_payload.get("new_value") - old_value = change_payload.get("old_value") - if value_converter := self._audit_log_entry_converters.get(key): - new_value = value_converter(new_value) if new_value is not None else None - old_value = value_converter(old_value) if old_value is not None else None - - elif not isinstance(key, audit_log_models.AuditLogChangeKey): - _LOGGER.debug("Unknown audit log change key found %r", key) - - changes.append(audit_log_models.AuditLogChange(key=key, new_value=new_value, old_value=old_value)) + else: + changes = () target_id: typing.Optional[snowflakes.Snowflake] = None if (raw_target_id := entry_payload["target_id"]) is not None: @@ -803,7 +799,7 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em url = payload.get("url") color = color_models.Color(payload["color"]) if "color" in payload else None timestamp = time.iso8601_datetime_string_to_datetime(payload["timestamp"]) if "timestamp" in payload else None - fields: typing.Optional[typing.MutableSequence[embed_models.EmbedField]] = None + fields: typing.Optional[typing.Tuple[embed_models.EmbedField, ...]] = None image: typing.Optional[embed_models.EmbedImage[files.AsyncReader]] = None if (image_payload := payload.get("image")) and "url" in image_payload: @@ -869,14 +865,14 @@ def deserialize_embed(self, payload: data_binding.JSONObject) -> embed_models.Em footer = embed_models.EmbedFooter(text=footer_payload.get("text"), icon=icon) if fields_array := payload.get("fields"): - fields = [] - for field_payload in fields_array: - field = embed_models.EmbedField( + fields = tuple( + embed_models.EmbedField( name=field_payload["name"], value=field_payload["value"], inline=field_payload.get("inline", False), ) - fields.append(field) + for field_payload in fields_array + ) return embed_models.Embed.from_received_embed( title=title, @@ -1014,7 +1010,7 @@ def deserialize_custom_emoji(self, payload: data_binding.JSONObject) -> emoji_mo def deserialize_known_custom_emoji( self, payload: data_binding.JSONObject, *, guild_id: snowflakes.Snowflake ) -> emoji_models.KnownCustomEmoji: - role_ids = [snowflakes.Snowflake(role_id) for role_id in payload["roles"]] if "roles" in payload else [] + role_ids = tuple(snowflakes.Snowflake(role_id) for role_id in payload["roles"]) if "roles" in payload else () user: typing.Optional[user_models.User] = None if (raw_user := payload.get("user")) is not None: @@ -1072,30 +1068,27 @@ def deserialize_guild_widget(self, payload: data_binding.JSONObject) -> guild_mo return guild_models.GuildWidget(app=self._app, channel_id=channel_id, is_enabled=payload["enabled"]) - def deserialize_welcome_screen(self, payload: data_binding.JSONObject) -> guild_models.WelcomeScreen: - channels: typing.List[guild_models.WelcomeChannel] = [] - - for channel_payload in payload["welcome_channels"]: - emoji_id = channel_payload["emoji_id"] - emoji_name = channel_payload["emoji_name"] - - emoji: typing.Optional[emoji_models.Emoji] = None - if emoji_name is not None: - if emoji_id is not None: - emoji = emoji_models.CustomEmoji( - id=snowflakes.Snowflake(emoji_id), name=emoji_name, is_animated=None - ) - else: - emoji = emoji_models.UnicodeEmoji(emoji_name) - - channels.append( - guild_models.WelcomeChannel( - channel_id=snowflakes.Snowflake(channel_payload["channel_id"]), - description=channel_payload["description"], - emoji=emoji, - ) - ) + def _deserialize_welcome_channel(self, payload: data_binding.JSONObject) -> guild_models.WelcomeChannel: + emoji_id = payload["emoji_id"] + emoji_name = payload["emoji_name"] + + emoji: typing.Optional[emoji_models.Emoji] = None + if emoji_name is not None: + if emoji_id is not None: + emoji = emoji_models.CustomEmoji(id=snowflakes.Snowflake(emoji_id), name=emoji_name, is_animated=None) + else: + emoji = emoji_models.UnicodeEmoji(emoji_name) + + return guild_models.WelcomeChannel( + channel_id=snowflakes.Snowflake(payload["channel_id"]), + description=payload["description"], + emoji=emoji, + ) + def deserialize_welcome_screen(self, payload: data_binding.JSONObject) -> guild_models.WelcomeScreen: + channels = tuple( + self._deserialize_welcome_channel(channel_payload) for channel_payload in payload["welcome_channels"] + ) return guild_models.WelcomeScreen(description=payload["description"], channels=channels) def serialize_welcome_channel(self, welcome_channel: guild_models.WelcomeChannel) -> data_binding.JSONObject: @@ -1142,7 +1135,7 @@ def deserialize_member( return guild_models.Member( user=user, guild_id=guild_id, - role_ids=role_ids, + role_ids=tuple(role_ids), joined_at=joined_at, nickname=payload.get("nick"), premium_since=premium_since, @@ -1281,7 +1274,7 @@ def deserialize_guild_preview(self, payload: data_binding.JSONObject) -> guild_m id=guild_id, name=payload["name"], icon_hash=payload["icon"], - features=[guild_models.GuildFeature(feature) for feature in payload["features"]], + features=tuple(guild_models.GuildFeature(feature) for feature in payload["features"]), splash_hash=payload["splash"], discovery_splash_hash=payload["discovery_splash"], emojis=emojis, @@ -1310,7 +1303,7 @@ def _set_guild_attributes(self, payload: data_binding.JSONObject) -> _GuildField id=snowflakes.Snowflake(payload["id"]), name=payload["name"], icon_hash=payload["icon"], - features=[guild_models.GuildFeature(feature) for feature in payload["features"]], + features=tuple(guild_models.GuildFeature(feature) for feature in payload["features"]), splash_hash=payload["splash"], # This is documented as always being present, but we have found old guilds where this is # not present. Quicker to just assume the documentation is wrong at this point than try @@ -1516,7 +1509,7 @@ def _set_invite_attributes(self, payload: data_binding.JSONObject) -> _InviteFie app=self._app, id=snowflakes.Snowflake(guild_payload["id"]), name=guild_payload["name"], - features=[guild_models.GuildFeature(feature) for feature in guild_payload["features"]], + features=tuple(guild_models.GuildFeature(feature) for feature in guild_payload["features"]), icon_hash=guild_payload["icon"], splash_hash=guild_payload["splash"], banner_hash=guild_payload["banner"], @@ -1675,7 +1668,7 @@ def _deserialize_sticker(self, payload: data_binding.JSONObject) -> message_mode pack_id=snowflakes.Snowflake(payload["pack_id"]), name=payload["name"], description=payload["description"], - tags=[tag.strip() for tag in payload["tags"].split(",")] if "tags" in payload else [], + tags=tuple(tag.strip() for tag in payload["tags"].split(",")) if "tags" in payload else (), asset_hash=payload["asset"], format_type=message_models.StickerFormatType(payload["format_type"]), ) @@ -1706,17 +1699,17 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long else: edited_timestamp = None - attachments: undefined.UndefinedOr[typing.MutableSequence[message_models.Attachment]] = undefined.UNDEFINED - if "attachments" in payload: - attachments = [self._deserialize_message_attachment(attachment) for attachment in payload["attachments"]] + attachments: undefined.UndefinedOr[typing.Tuple[message_models.Attachment, ...]] = undefined.UNDEFINED + if raw_attachments := payload.get("attachments"): + attachments = tuple(self._deserialize_message_attachment(attachment) for attachment in raw_attachments) - embeds: undefined.UndefinedOr[typing.Sequence[embed_models.Embed]] = undefined.UNDEFINED - if "embeds" in payload: - embeds = [self.deserialize_embed(embed) for embed in payload["embeds"]] + embeds: undefined.UndefinedOr[typing.Tuple[embed_models.Embed, ...]] = undefined.UNDEFINED + if raw_embeds := payload.get("embeds"): + embeds = tuple(self.deserialize_embed(embed) for embed in raw_embeds) - reactions: undefined.UndefinedOr[typing.MutableSequence[message_models.Reaction]] = undefined.UNDEFINED - if "reactions" in payload: - reactions = [self._deserialize_message_reaction(reaction) for reaction in payload["reactions"]] + reactions: undefined.UndefinedOr[typing.Tuple[message_models.Reaction, ...]] = undefined.UNDEFINED + if raw_reactions := payload.get("reactions"): + reactions = tuple(self._deserialize_message_reaction(reaction) for reaction in raw_reactions) activity: undefined.UndefinedOr[message_models.MessageActivity] = undefined.UNDEFINED if "activity" in payload: @@ -1737,9 +1730,9 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long else: referenced_message = None - stickers: undefined.UndefinedOr[typing.Sequence[message_models.Sticker]] = undefined.UNDEFINED - if "stickers" in payload: - stickers = [self._deserialize_sticker(sticker) for sticker in payload["stickers"]] + stickers: undefined.UndefinedOr[typing.Tuple[message_models.Sticker, ...]] = undefined.UNDEFINED + if raw_stickers := payload.get("stickers"): + stickers = tuple(self._deserialize_sticker(sticker) for sticker in raw_stickers) message = message_models.PartialMessage( app=self._app, @@ -1779,9 +1772,9 @@ def deserialize_partial_message( # noqa CFQ001 - Function too long if raw_users := payload.get("mentions"): users = {u.id: u for u in map(self.deserialize_user, raw_users)} - role_ids: undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]] = undefined.UNDEFINED + role_ids: undefined.UndefinedOr[typing.Tuple[snowflakes.Snowflake, ...]] = undefined.UNDEFINED if raw_role_ids := payload.get("mention_roles"): - role_ids = [snowflakes.Snowflake(i) for i in raw_role_ids] + role_ids = tuple(snowflakes.Snowflake(i) for i in raw_role_ids) everyone = payload.get("mention_everyone", undefined.UNDEFINED) @@ -1810,15 +1803,15 @@ def deserialize_message( # noqa CFQ001 - Function too long if (raw_edited_timestamp := payload["edited_timestamp"]) is not None: edited_timestamp = time.iso8601_datetime_string_to_datetime(raw_edited_timestamp) - attachments = [self._deserialize_message_attachment(attachment) for attachment in payload["attachments"]] + attachments = tuple(self._deserialize_message_attachment(attachment) for attachment in payload["attachments"]) - embeds = [self.deserialize_embed(embed) for embed in payload["embeds"]] + embeds = tuple(self.deserialize_embed(embed) for embed in payload["embeds"]) if "reactions" in payload: - reactions = [self._deserialize_message_reaction(reaction) for reaction in payload["reactions"]] + reactions = tuple(self._deserialize_message_reaction(reaction) for reaction in payload["reactions"]) else: - reactions = [] + reactions = () activity: typing.Optional[message_models.MessageActivity] = None if "activity" in payload: @@ -1840,10 +1833,10 @@ def deserialize_message( # noqa CFQ001 - Function too long application = self._deserialize_message_application(payload["application"]) if "stickers" in payload: - stickers = [self._deserialize_sticker(sticker) for sticker in payload["stickers"]] + stickers = tuple(self._deserialize_sticker(sticker) for sticker in payload["stickers"]) else: - stickers = [] + stickers = () message = message_models.Message( app=self._app, @@ -1881,9 +1874,11 @@ def deserialize_message( # noqa CFQ001 - Function too long if raw_users := payload.get("mentions"): users = {u.id: u for u in map(self.deserialize_user, raw_users)} - role_ids: typing.Sequence[snowflakes.Snowflake] = [] if raw_role_ids := payload.get("mention_roles"): - role_ids = [snowflakes.Snowflake(i) for i in raw_role_ids] + role_ids = tuple(snowflakes.Snowflake(i) for i in raw_role_ids) + + else: + role_ids = () everyone = payload.get("mention_everyone", False) @@ -1901,106 +1896,95 @@ def deserialize_message( # noqa CFQ001 - Function too long # PRESENCE MODELS # ################### - def deserialize_member_presence( # noqa: CFQ001 - Max function length - self, - payload: data_binding.JSONObject, - *, - guild_id: undefined.UndefinedOr[snowflakes.Snowflake] = undefined.UNDEFINED, - ) -> presence_models.MemberPresence: - activities = [] - for activity_payload in payload["activities"]: - timestamps: typing.Optional[presence_models.ActivityTimestamps] = None - if "timestamps" in activity_payload: - timestamps_payload = activity_payload["timestamps"] - start = ( - time.unix_epoch_to_datetime(timestamps_payload["start"]) if "start" in timestamps_payload else None - ) - end = time.unix_epoch_to_datetime(timestamps_payload["end"]) if "end" in timestamps_payload else None - timestamps = presence_models.ActivityTimestamps(start=start, end=end) + def _deserialize_rich_activity(self, payload: data_binding.JSONObject) -> presence_models.RichActivity: + timestamps: typing.Optional[presence_models.ActivityTimestamps] = None + if timestamps_payload := payload.get("timestamps"): - application_id = ( - snowflakes.Snowflake(activity_payload["application_id"]) - if "application_id" in activity_payload - else None - ) + start = time.unix_epoch_to_datetime(timestamps_payload["start"]) if "start" in timestamps_payload else None + end = time.unix_epoch_to_datetime(timestamps_payload["end"]) if "end" in timestamps_payload else None + timestamps = presence_models.ActivityTimestamps(start=start, end=end) - party: typing.Optional[presence_models.ActivityParty] = None - if "party" in activity_payload: - party_payload = activity_payload["party"] - - current_size: typing.Optional[int] - max_size: typing.Optional[int] - if "size" in party_payload: - raw_current_size, raw_max_size = party_payload["size"] - current_size = int(raw_current_size) - max_size = int(raw_max_size) - else: - current_size = max_size = None - - party = presence_models.ActivityParty( - id=party_payload.get("id"), current_size=current_size, max_size=max_size - ) + application_id: typing.Optional[snowflakes.Snowflake] = None + if raw_application_id := payload.get("application_id"): + application_id = snowflakes.Snowflake(raw_application_id) - assets: typing.Optional[presence_models.ActivityAssets] = None - if "assets" in activity_payload: - assets_payload = activity_payload["assets"] - assets = presence_models.ActivityAssets( - large_image=assets_payload.get("large_image"), - large_text=assets_payload.get("large_text"), - small_image=assets_payload.get("small_image"), - small_text=assets_payload.get("small_text"), - ) + party: typing.Optional[presence_models.ActivityParty] = None + if party_payload := payload.get("party"): + current_size: typing.Optional[int] + max_size: typing.Optional[int] + if "size" in party_payload: + raw_current_size, raw_max_size = party_payload["size"] + current_size = int(raw_current_size) + max_size = int(raw_max_size) + else: + current_size = max_size = None - secrets: typing.Optional[presence_models.ActivitySecret] = None - if "secrets" in activity_payload: - secrets_payload = activity_payload["secrets"] - secrets = presence_models.ActivitySecret( - join=secrets_payload.get("join"), - spectate=secrets_payload.get("spectate"), - match=secrets_payload.get("match"), - ) + party = presence_models.ActivityParty( + id=party_payload.get("id"), current_size=current_size, max_size=max_size + ) - emoji: typing.Optional[emoji_models.Emoji] = None - raw_emoji = activity_payload.get("emoji") - if raw_emoji is not None: - emoji = self.deserialize_emoji(raw_emoji) - - activity = presence_models.RichActivity( - name=activity_payload["name"], - # RichActivity's generated init already declares a converter for the "type" field - type=activity_payload["type"], - url=activity_payload.get("url"), - created_at=time.unix_epoch_to_datetime(activity_payload["created_at"]), - timestamps=timestamps, - application_id=application_id, - details=activity_payload.get("details"), - state=activity_payload.get("state"), - emoji=emoji, - party=party, - assets=assets, - secrets=secrets, - is_instance=activity_payload.get("instance"), # TODO: can we safely default this to False? - flags=presence_models.ActivityFlag(activity_payload["flags"]) if "flags" in activity_payload else None, - buttons=activity_payload.get("buttons") or [], + assets: typing.Optional[presence_models.ActivityAssets] = None + if assets_payload := payload.get("assets"): + assets = presence_models.ActivityAssets( + large_image=assets_payload.get("large_image"), + large_text=assets_payload.get("large_text"), + small_image=assets_payload.get("small_image"), + small_text=assets_payload.get("small_text"), ) - activities.append(activity) - client_status_payload = payload["client_status"] - desktop = ( - presence_models.Status(client_status_payload["desktop"]) - if "desktop" in client_status_payload - else presence_models.Status.OFFLINE - ) - mobile = ( - presence_models.Status(client_status_payload["mobile"]) - if "mobile" in client_status_payload - else presence_models.Status.OFFLINE - ) - web = ( - presence_models.Status(client_status_payload["web"]) - if "web" in client_status_payload - else presence_models.Status.OFFLINE + secrets: typing.Optional[presence_models.ActivitySecret] = None + if secrets_payload := payload.get("secrets"): + secrets = presence_models.ActivitySecret( + join=secrets_payload.get("join"), + spectate=secrets_payload.get("spectate"), + match=secrets_payload.get("match"), + ) + + emoji: typing.Optional[emoji_models.Emoji] = None + if raw_emoji := payload.get("emoji"): + emoji = self.deserialize_emoji(raw_emoji) + + return presence_models.RichActivity( + name=payload["name"], + # RichActivity's generated init already declares a converter for the "type" field + type=payload["type"], + url=payload.get("url"), + created_at=time.unix_epoch_to_datetime(payload["created_at"]), + timestamps=timestamps, + application_id=application_id, + details=payload.get("details"), + state=payload.get("state"), + emoji=emoji, + party=party, + assets=assets, + secrets=secrets, + is_instance=payload.get("instance"), # TODO: can we safely default this to False? + flags=presence_models.ActivityFlag(payload["flags"]) if "flags" in payload else None, + buttons=tuple(payload["buttons"]) if "buttons" in payload else (), + ) + + def deserialize_member_presence( + self, + payload: data_binding.JSONObject, + *, + guild_id: undefined.UndefinedOr[snowflakes.Snowflake] = undefined.UNDEFINED, + ) -> presence_models.MemberPresence: + activities = tuple( + self._deserialize_rich_activity(activity_payload) for activity_payload in payload["activities"] ) + client_status_payload = payload["client_status"] + desktop = presence_models.Status.OFFLINE + if raw_desktop := client_status_payload.get("desktop"): + desktop = presence_models.Status(raw_desktop) + + mobile = presence_models.Status.OFFLINE + if raw_mobile := client_status_payload.get("mobile"): + mobile = presence_models.Status(raw_mobile) + + web = presence_models.Status.OFFLINE + if raw_web := client_status_payload.get("web"): + web = presence_models.Status(raw_web) + client_status = presence_models.ClientStatus(desktop=desktop, mobile=mobile, web=web) return presence_models.MemberPresence( @@ -2192,16 +2176,14 @@ def deserialize_webhook(self, payload: data_binding.JSONObject) -> webhook_model application_id = snowflakes.Snowflake(raw_application_id) source_channel: typing.Optional[channel_models.PartialChannel] = None - if "source_channel" in payload: - raw_source_channel = payload["source_channel"] + if raw_source_channel := payload.get("source_channel"): # In this case the channel type isn't provided as we can safely # assume it's a news channel. raw_source_channel["type"] = channel_models.ChannelType.GUILD_NEWS source_channel = self.deserialize_partial_channel(raw_source_channel) source_guild: typing.Optional[guild_models.PartialGuild] = None - if "source_guild" in payload: - source_guild_payload = payload["source_guild"] + if source_guild_payload := payload.get("source_guild"): source_guild = guild_models.PartialGuild( app=self._app, id=snowflakes.Snowflake(source_guild_payload["id"]), diff --git a/hikari/impl/event_factory.py b/hikari/impl/event_factory.py index a0cdd76a2e..64115aed40 100644 --- a/hikari/impl/event_factory.py +++ b/hikari/impl/event_factory.py @@ -275,10 +275,10 @@ def deserialize_guild_emojis_update_event( old_emojis: typing.Optional[typing.Sequence[emojis_models.KnownCustomEmoji]], ) -> guild_events.EmojisUpdateEvent: guild_id = snowflakes.Snowflake(payload["guild_id"]) - emojis = [ + emojis = tuple( self._app.entity_factory.deserialize_known_custom_emoji(emoji, guild_id=guild_id) for emoji in payload["emojis"] - ] + ) return guild_events.EmojisUpdateEvent( app=self._app, shard=shard, guild_id=guild_id, emojis=emojis, old_emojis=old_emojis ) @@ -670,7 +670,7 @@ def deserialize_guild_member_chunk_event( for m in payload["members"] } # Note, these IDs may be returned as ints or strings based on whether they're over a certain value. - not_found = [snowflakes.Snowflake(sn) for sn in payload["not_found"]] if "not_found" in payload else [] + not_found = tuple(snowflakes.Snowflake(sn) for sn in payload["not_found"]) if "not_found" in payload else () if presence_payloads := payload.get("presences"): presences = { diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index e06d936f11..423d26377e 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -523,7 +523,7 @@ class RESTClientImpl(rest_api.RESTClient): global_rate_limit: rate_limits.ManualRateLimiter """Global ratelimiter.""" - @attr.define(auto_exc=True, repr=False, weakref_slot=False) + @attr.frozen(auto_exc=True, repr=False, weakref_slot=False) class _RetryRequest(RuntimeError): ... diff --git a/hikari/impl/special_endpoints.py b/hikari/impl/special_endpoints.py index d0c8bd1eef..6d753e4efa 100644 --- a/hikari/impl/special_endpoints.py +++ b/hikari/impl/special_endpoints.py @@ -41,7 +41,6 @@ from hikari import snowflakes from hikari import undefined from hikari.api import special_endpoints -from hikari.internal import attr_extensions from hikari.internal import data_binding from hikari.internal import routes from hikari.internal import time @@ -141,7 +140,6 @@ async def _keep_typing(self) -> None: # As a note, slotting allows us to override the settable properties while staying within the interface's spec. -@attr_extensions.with_copy @attr.define(kw_only=True, weakref_slot=False) class GuildBuilder(special_endpoints.GuildBuilder): """Result type of `hikari.api.rest.RESTClient.guild_builder`. @@ -217,14 +215,12 @@ class GuildBuilder(special_endpoints.GuildBuilder): """ # Required arguments. - _entity_factory: entity_factory_.EntityFactory = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) - _executor: typing.Optional[concurrent.futures.Executor] = attr.field( - metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + _entity_factory: entity_factory_.EntityFactory = attr.field() + _executor: typing.Optional[concurrent.futures.Executor] = attr.field() _name: str = attr.field() _request_call: typing.Callable[ ..., typing.Coroutine[None, None, typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]] - ] = attr.field(metadata={attr_extensions.SKIP_DEEP_COPY: True}) + ] = attr.field() # Optional arguments. default_message_notifications: undefined.UndefinedOr[guilds.GuildMessageNotificationsLevel] = attr.field( diff --git a/hikari/internal/attr_extensions.py b/hikari/internal/attr_extensions.py deleted file mode 100644 index ea7031360e..0000000000 --- a/hikari/internal/attr_extensions.py +++ /dev/null @@ -1,260 +0,0 @@ -# -*- coding: utf-8 -*- -# cython: language_level=3 -# Copyright (c) 2020 Nekokatt -# Copyright (c) 2021 davfsa -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -"""Utility for extending and optimising the usage of `attr` models.""" -from __future__ import annotations - -__all__: typing.List[str] = [ - "with_copy", - "copy_attrs", - "deep_copy_attrs", - "invalidate_deep_copy_cache", - "invalidate_shallow_copy_cache", -] - -import copy as std_copy -import logging -import typing - -import attr - -ModelT = typing.TypeVar("ModelT") -SKIP_DEEP_COPY: typing.Final[str] = "skip_deep_copy" - -_DEEP_COPIERS: typing.MutableMapping[ - typing.Any, typing.Callable[[typing.Any, typing.MutableMapping[int, typing.Any]], None] -] = {} -_SHALLOW_COPIERS: typing.MutableMapping[typing.Any, typing.Callable[[typing.Any], typing.Any]] = {} -_LOGGER = logging.getLogger("hikari.models") - - -def invalidate_shallow_copy_cache() -> None: - """Remove all the globally cached copy functions.""" - _LOGGER.debug("Invalidating attr extensions shallow copy cache") - _SHALLOW_COPIERS.clear() - - -def invalidate_deep_copy_cache() -> None: - """Remove all the globally cached generated deep copy functions.""" - _LOGGER.debug("Invalidating attr extensions deep copy cache") - _DEEP_COPIERS.clear() - - -def get_fields_definition( - cls: typing.Type[ModelT], -) -> typing.Tuple[ - typing.Sequence[typing.Tuple[attr.Attribute[typing.Any], str]], typing.Sequence[attr.Attribute[typing.Any]] -]: - """Get a sequence of init key-words to their relative attribute. - - Parameters - ---------- - cls : typing.Type[ModelT] - The attrs class to get the fields definition for. - - Returns - ------- - typing.Sequence[typing.Tuple[builtins.str, builtins.str]] - A sequence of tuples of string attribute names to string key-word names. - """ - init_results = [] - non_init_results = [] - - for field in attr.fields(cls): - if field.init: - key_word = field.name[1:] if field.name.startswith("_") else field.name - init_results.append((field, key_word)) - else: - non_init_results.append(field) - - return init_results, non_init_results - - -# TODO: can we get if the init wasn't generated for the class? -def generate_shallow_copier(cls: typing.Type[ModelT]) -> typing.Callable[[ModelT], ModelT]: - """Generate a function for shallow copying an attrs model with `init` enabled. - - Parameters - ---------- - cls : typing.Type[ModelT] - The attrs class to generate a shallow copying function for. - - Returns - ------- - typing.Callable[[ModelT], ModelT] - The generated shallow copying function. - """ - # This import is delayed to avoid a circular import error on startup - from hikari.internal import ux - - kwargs, setters = get_fields_definition(cls) - kwargs = ",".join(f"{kwarg}=m.{attribute.name}" for attribute, kwarg in kwargs) - setters = ";".join(f"r.{attribute.name}=m.{attribute.name}" for attribute in setters) + ";" if setters else "" - code = f"def copy(m):r=cls({kwargs});{setters}return r" - globals_ = {"cls": cls} - _LOGGER.log(ux.TRACE, "generating shallow copy function for %r: %r", cls, code) - exec(code, globals_) # noqa: S102 - Use of exec detected. - return typing.cast("typing.Callable[[ModelT], ModelT]", globals_["copy"]) - - -def get_or_generate_shallow_copier(cls: typing.Type[ModelT]) -> typing.Callable[[ModelT], ModelT]: - """Get a cached shallow copying function for a an attrs class or generate it. - - Parameters - ---------- - cls : typing.Type[ModelT] - The class to get or generate and cache a shallow copying function for. - - Returns - ------- - typing.Callable[[ModelT], ModelT] - The cached or generated shallow copying function. - """ - try: - return _SHALLOW_COPIERS[cls] - except KeyError: - copier = generate_shallow_copier(cls) - _SHALLOW_COPIERS[cls] = copier - return copier - - -def copy_attrs(model: ModelT) -> ModelT: - """Shallow copy an attrs model with `init` enabled. - - Parameters - ---------- - model : ModelT - The attrs model to shallow copy. - - Returns - ------- - ModelT - The new shallow copied attrs model. - """ - return get_or_generate_shallow_copier(type(model))(model) - - -def _normalize_kwargs_and_setters( - kwargs: typing.Sequence[typing.Tuple[attr.Attribute[typing.Any], str]], - setters: typing.Sequence[attr.Attribute[typing.Any]], -) -> typing.Iterable[attr.Attribute[typing.Any]]: - for attribute, _ in kwargs: - yield attribute - - yield from setters - - -def generate_deep_copier( - cls: typing.Type[ModelT], -) -> typing.Callable[[ModelT, typing.MutableMapping[int, typing.Any]], None]: - """Generate a function for deep copying an attrs model with `init` enabled. - - Parameters - ---------- - cls : typing.Type[ModelT] - The attrs class to generate a deep copying function for. - - Returns - ------- - typing.Callable[[ModelT], ModelT] - The generated deep copying function. - """ - kwargs, setters = get_fields_definition(cls) - - # Explicitly handle the case of an attrs model with no fields by returning - # an empty lambda to avoid a SyntaxError being raised. - if not kwargs and not setters: - return lambda _, __: None - - setters = ";".join( - f"m.{attribute.name}=std_copy(m.{attribute.name},memo)if(id_:=id(m.{attribute.name}))not in memo else memo[id_]" - for attribute in _normalize_kwargs_and_setters(kwargs, setters) - if not attribute.metadata.get(SKIP_DEEP_COPY) - ) - code = f"def deep_copy(m,memo):{setters}" - globals_ = {"std_copy": std_copy.deepcopy, "cls": cls} - _LOGGER.debug("generating deep copy function for %r: %r", cls, code) - exec(code, globals_) # noqa: S102 - Use of exec detected. - return typing.cast("typing.Callable[[ModelT, typing.MutableMapping[int, typing.Any]], None]", globals_["deep_copy"]) - - -def get_or_generate_deep_copier( - cls: typing.Type[ModelT], -) -> typing.Callable[[ModelT, typing.MutableMapping[int, typing.Any]], None]: - """Get a cached shallow copying function for a an attrs class or generate it. - - Parameters - ---------- - cls : typing.Type[ModelT] - The class to get or generate and cache a shallow copying function for. - - Returns - ------- - typing.Callable[[ModelT], ModelT] - The cached or generated shallow copying function. - """ - try: - return _DEEP_COPIERS[cls] - except KeyError: - copier = generate_deep_copier(cls) - _DEEP_COPIERS[cls] = copier - return copier - - -def deep_copy_attrs(model: ModelT, memo: typing.Optional[typing.MutableMapping[int, typing.Any]] = None) -> ModelT: - """Deep copy an attrs model with `init` enabled. - - Parameters - ---------- - model : ModelT - The attrs model to deep copy. - memo : typing.Optional[typing.MutableMapping[builtins.int, typing.Any]] - A memo dictionary of objects already copied during the current copying - pass, see https://docs.python.org/3/library/copy.html for more details. - - !!! note - This won't deep copy attributes where "skip_deep_copy" is set to - `builtins.True` in their metadata. - - Returns - ------- - ModelT - The new deep copied attrs model. - """ - if memo is None: - memo = {} - - new_object = std_copy.copy(model) - memo[id(model)] = new_object - get_or_generate_deep_copier(type(model))(new_object, memo) - return new_object - - -def with_copy(cls: typing.Type[ModelT]) -> typing.Type[ModelT]: - """Add a custom implementation for copying attrs models to a class. - - !!! note - This will only work if the class has an attrs generated init. - """ - cls.__copy__ = copy_attrs # type: ignore[attr-defined] - cls.__deepcopy__ = deep_copy_attrs # type: ignore[attr-defined] - return cls diff --git a/hikari/internal/cache.py b/hikari/internal/cache.py index 7e6b01564b..36c6ba6a3f 100644 --- a/hikari/internal/cache.py +++ b/hikari/internal/cache.py @@ -38,8 +38,6 @@ "VoiceStateData", "RefCell", "unwrap_ref_cell", - "copy_guild_channel", - "Cache3DMappingView", "DataT", "KeyT", "ValueT", @@ -63,7 +61,6 @@ from hikari import undefined from hikari import voices from hikari.api import cache -from hikari.internal import attr_extensions from hikari.internal import collections if typing.TYPE_CHECKING: @@ -110,10 +107,6 @@ def __init__( self._data = items self._predicate = predicate - @classmethod - def _copy(cls, value: ValueT) -> ValueT: - return copy.copy(value) - def __contains__(self, key: typing.Any) -> bool: return key in self._data and (self._predicate is None or self._predicate(self._data[key])) @@ -126,10 +119,7 @@ def __getitem__(self, key: KeyT) -> ValueT: if self._builder is not None: entry = self._builder(entry) # type: ignore[arg-type] - else: - entry = self._copy(entry) # type: ignore[arg-type] - - return entry + return entry # type: ignore[return-value] def __iter__(self) -> typing.Iterator[KeyT]: if self._predicate is None: @@ -183,7 +173,6 @@ def iterator(self) -> iterators.LazyIterator[ValueT]: return iterators.FlatLazyIterator(()) -@attr_extensions.with_copy @attr.define(repr=False, hash=False, weakref_slot=False) class GuildRecord: """An object used for storing guild specific cached information in-memory. @@ -325,7 +314,6 @@ def build_from_entity(cls: typing.Type[DataT], entity: ValueT, /) -> DataT: """ -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class InviteData(BaseData[invites.InviteWithMetadata]): """A data model for storing invite data in an in-memory cache.""" @@ -358,8 +346,8 @@ def build_entity(self, app: traits.RESTAware, /) -> invites.InviteWithMetadata: channel=None, guild=None, app=app, - inviter=self.inviter.copy() if self.inviter else None, - target_user=self.target_user.copy() if self.target_user else None, + inviter=self.inviter.object if self.inviter else None, + target_user=self.target_user.object if self.target_user else None, expires_at=self.created_at + self.max_age if self.max_age else None, ) @@ -373,10 +361,10 @@ def build_from_entity( target_user: typing.Optional[RefCell[users_.User]] = None, ) -> InviteData: if not inviter and invite.inviter: - inviter = RefCell(copy.copy(invite.inviter)) + inviter = RefCell(invite.inviter) if not target_user and invite.target_user: - target_user = RefCell(copy.copy(invite.target_user)) + target_user = RefCell(invite.target_user) return cls( code=invite.code, @@ -393,7 +381,6 @@ def build_from_entity( ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class MemberData(BaseData[guilds.Member]): """A data model for storing member data in an in-memory cache.""" @@ -422,7 +409,7 @@ def build_from_entity( is_deaf=member.is_deaf, is_mute=member.is_mute, is_pending=member.is_pending, - user=user or RefCell(copy.copy(member.user)), + user=user or RefCell(member.user), # role_ids is a special case as it may be mutable so we want to ensure it's immutable when cached. role_ids=tuple(member.role_ids), ) @@ -437,11 +424,10 @@ def build_entity(self, _: traits.RESTAware, /) -> guilds.Member: is_deaf=self.is_deaf, is_mute=self.is_mute, is_pending=self.is_pending, - user=self.user.copy(), + user=self.user.object, ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class KnownCustomEmojiData(BaseData[emojis.KnownCustomEmoji]): """A data model for storing known custom emoji data in an in-memory cache.""" @@ -465,7 +451,7 @@ def build_from_entity( user: typing.Optional[RefCell[users_.User]] = None, ) -> KnownCustomEmojiData: if not user and emoji.user: - user = RefCell(copy.copy(emoji.user)) + user = RefCell(emoji.user) return cls( id=emoji.id, @@ -491,11 +477,10 @@ def build_entity(self, app: traits.RESTAware, /) -> emojis.KnownCustomEmoji: is_managed=self.is_managed, is_available=self.is_available, app=app, - user=self.user.copy() if self.user else None, + user=self.user.object if self.user else None, ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class RichActivityData(BaseData[presences.RichActivity]): """A data model for storing rich activity data in an in-memory cache.""" @@ -528,15 +513,11 @@ def build_from_entity( pass elif isinstance(activity.emoji, emojis.CustomEmoji): - emoji = RefCell(copy.copy(activity.emoji)) + emoji = RefCell(activity.emoji) elif activity.emoji: emoji = activity.emoji.name - timestamps = copy.copy(activity.timestamps) if activity.timestamps is not None else None - party = copy.copy(activity.party) if activity.party is not None else None - assets = copy.copy(activity.assets) if activity.assets is not None else None - secrets = copy.copy(activity.secrets) if activity.secrets is not None else None return cls( name=activity.name, url=activity.url, @@ -548,17 +529,17 @@ def build_from_entity( is_instance=activity.is_instance, flags=activity.flags, emoji=emoji, - timestamps=timestamps, - party=party, - assets=assets, - secrets=secrets, + timestamps=activity.timestamps, + party=activity.party, + assets=activity.assets, + secrets=activity.secrets, buttons=tuple(activity.buttons), ) def build_entity(self, _: traits.RESTAware, /) -> presences.RichActivity: emoji: typing.Optional[emojis.Emoji] = None if isinstance(self.emoji, RefCell): - emoji = self.emoji.copy() + emoji = self.emoji.object elif self.emoji is not None: emoji = emojis.UnicodeEmoji(self.emoji) @@ -573,16 +554,15 @@ def build_entity(self, _: traits.RESTAware, /) -> presences.RichActivity: is_instance=self.is_instance, flags=self.flags, state=self.state, - timestamps=copy.copy(self.timestamps) if self.timestamps is not None else None, - party=copy.copy(self.party) if self.party is not None else None, - assets=copy.copy(self.assets) if self.assets is not None else None, - secrets=copy.copy(self.secrets) if self.secrets is not None else None, + timestamps=self.timestamps, + party=self.party, + assets=self.assets, + secrets=self.secrets, emoji=emoji, buttons=self.buttons, ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class MemberPresenceData(BaseData[presences.MemberPresence]): """A data model for storing presence data in an in-memory cache.""" @@ -602,7 +582,7 @@ def build_from_entity(cls, presence: presences.MemberPresence, /) -> MemberPrese guild_id=presence.guild_id, visible_status=presence.visible_status, activities=tuple(RichActivityData.build_from_entity(activity) for activity in presence.activities), - client_status=copy.copy(presence.client_status), + client_status=presence.client_status, ) def build_entity(self, app: traits.RESTAware, /) -> presences.MemberPresence: @@ -612,11 +592,10 @@ def build_entity(self, app: traits.RESTAware, /) -> presences.MemberPresence: visible_status=self.visible_status, app=app, activities=[activity.build_entity(app) for activity in self.activities], - client_status=copy.copy(self.client_status), + client_status=self.client_status, ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class MentionsData(BaseData[messages.Mentions]): """A model for storing message mentions data in an in-memory cache.""" @@ -635,13 +614,13 @@ def build_from_entity( users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, RefCell[users_.User]]] = undefined.UNDEFINED, ) -> MentionsData: if not users and mentions.users is not undefined.UNDEFINED: - users = {user_id: RefCell(copy.copy(user)) for user_id, user in mentions.users.items()} + users = {user_id: RefCell(user) for user_id, user in mentions.users.items()} channels: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, "channels_.PartialChannel"] ] = undefined.UNDEFINED if mentions.channels is not undefined.UNDEFINED: - channels = {channel_id: copy.copy(channel) for channel_id, channel in mentions.channels.items()} + channels = dict(mentions.channels.items()) return cls( users=users, @@ -655,13 +634,13 @@ def build_entity( ) -> messages.Mentions: users: undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, users_.User]] = undefined.UNDEFINED if self.users is not undefined.UNDEFINED: - users = {user_id: user.copy() for user_id, user in self.users.items()} + users = {user_id: user.object for user_id, user in self.users.items()} channels: undefined.UndefinedOr[ typing.Mapping[snowflakes.Snowflake, channels_.PartialChannel] ] = undefined.UNDEFINED if self.channels is not undefined.UNDEFINED: - channels = {channel_id: copy.copy(channel) for channel_id, channel in self.channels.items()} + channels = dict(self.channels) return messages.Mentions( message=message or NotImplemented, @@ -682,18 +661,19 @@ def update( self.users = users elif mention.users is not undefined.UNDEFINED: - self.users = {user_id: RefCell(copy.copy(user)) for user_id, user in mention.users.items()} + self.users = {user_id: RefCell(user) for user_id, user in mention.users.items()} if mention.role_ids is not undefined.UNDEFINED: self.role_ids = tuple(mention.role_ids) if mention.channels is not undefined.UNDEFINED: - self.channels = {channel_id: copy.copy(channel) for channel_id, channel in mention.channels.items()} + self.channels = dict(mention.channels) if mention.everyone is not undefined.UNDEFINED: self.everyone = mention.everyone +# TODO: this should be removed when embed switched to received vs builder def _copy_embed(embed: embeds_.Embed) -> embeds_.Embed: return embeds_.Embed.from_received_embed( title=embed.title, @@ -703,15 +683,14 @@ def _copy_embed(embed: embeds_.Embed) -> embeds_.Embed: timestamp=embed.timestamp, image=copy.copy(embed.image) if embed.image else None, thumbnail=copy.copy(embed.thumbnail) if embed.thumbnail else None, - video=copy.copy(embed.video) if embed.video else None, + video=embed.video, author=copy.copy(embed.author) if embed.author else None, - provider=copy.copy(embed.provider) if embed.provider else None, + provider=embed.provider, footer=copy.copy(embed.footer) if embed.footer else None, fields=list(map(copy.copy, embed.fields)), # type: ignore[arg-type] ) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class MessageData(BaseData[messages.Message]): """A model for storing message data in an in-memory cache.""" @@ -763,24 +742,24 @@ def build_from_entity( id=message.id, channel_id=message.channel_id, guild_id=message.guild_id, - author=author or RefCell(copy.copy(message.author)), + author=author or RefCell(message.author), member=member, content=message.content, timestamp=message.timestamp, edited_timestamp=message.edited_timestamp, is_tts=message.is_tts, mentions=MentionsData.build_from_entity(message.mentions, users=mention_users), - attachments=tuple(map(copy.copy, message.attachments)), + attachments=tuple(message.attachments), embeds=tuple(map(_copy_embed, message.embeds)), - reactions=tuple(map(copy.copy, message.reactions)), + reactions=tuple(message.reactions), is_pinned=message.is_pinned, webhook_id=message.webhook_id, type=message.type, - activity=copy.copy(message.activity) if message.activity else None, - application=copy.copy(message.application) if message.application else None, - message_reference=copy.copy(message.message_reference) if message.message_reference else None, - flags=copy.copy(message.flags), - stickers=tuple(map(copy.copy, message.stickers)), + activity=message.activity, + application=message.application, + message_reference=message.message_reference, + flags=message.flags, + stickers=tuple(message.stickers), nonce=message.nonce, referenced_message=referenced_message, ) @@ -798,24 +777,24 @@ def build_entity(self, app: traits.RESTAware, /) -> messages.Message: app=app, channel_id=self.channel_id, guild_id=self.guild_id, - author=self.author.copy(), + author=self.author.object, member=self.member.object.build_entity(app) if self.member else None, content=self.content, timestamp=self.timestamp, edited_timestamp=self.edited_timestamp, is_tts=self.is_tts, mentions=NotImplemented, - attachments=tuple(map(copy.copy, self.attachments)), - embeds=tuple(map(_copy_embed, self.embeds)), - reactions=tuple(map(copy.copy, self.reactions)), + attachments=tuple(self.attachments), + embeds=tuple(self.embeds), + reactions=tuple(self.reactions), is_pinned=self.is_pinned, webhook_id=self.webhook_id, type=self.type, - activity=copy.copy(self.activity) if self.activity else None, - application=copy.copy(self.application) if self.application else None, - message_reference=copy.copy(self.message_reference) if self.message_reference else None, + activity=self.activity, + application=self.application, + message_reference=self.message_reference, flags=self.flags, - stickers=tuple(map(copy.copy, self.stickers)), + stickers=tuple(self.stickers), nonce=self.nonce, referenced_message=referenced_message, ) @@ -841,15 +820,14 @@ def update( self.is_pinned = message.is_pinned if message.attachments is not undefined.UNDEFINED: - self.attachments = tuple(map(copy.copy, message.attachments)) + self.attachments = tuple(message.attachments) if message.embeds is not undefined.UNDEFINED: - self.embeds = tuple(map(_copy_embed, message.embeds)) + self.embeds = tuple(message.embeds) self.mentions.update(message.mentions, users=mention_users) -@attr_extensions.with_copy @attr.define(kw_only=True, repr=False, hash=False, weakref_slot=False) class VoiceStateData(BaseData[voices.VoiceState]): """A data model for storing voice state data in an in-memory cache.""" @@ -910,25 +888,13 @@ def build_from_entity( ) -@attr_extensions.with_copy @attr.define(repr=True, hash=False, weakref_slot=True) class Cell(typing.Generic[ValueT]): """Object used to store mutable references to a value in multiple places.""" object: ValueT = attr.field(repr=True) - def copy(self) -> ValueT: - """Get a copy of the contents of this cell. - - Returns - ------- - ValueT - The copied contents of this cell. - """ - return copy.copy(self.object) - -@attr_extensions.with_copy @attr.define(repr=False, hash=False, weakref_slot=False) class RefCell(typing.Generic[ValueT]): """Object used to track mutable references to a value in multiple places. @@ -942,16 +908,6 @@ class RefCell(typing.Generic[ValueT]): object: ValueT = attr.field(repr=True) ref_count: int = attr.field(default=0, kw_only=True) - def copy(self) -> ValueT: - """Get a copy of the contents of this cell. - - Returns - ------- - ValueT - The copied contents of this cell. - """ - return copy.copy(self.object) - def unwrap_ref_cell(cell: RefCell[ValueT]) -> ValueT: """Unwrap a `RefCell` instance to it's contents. @@ -966,27 +922,4 @@ def unwrap_ref_cell(cell: RefCell[ValueT]) -> ValueT: ValueT The reference cell's content. """ - return cell.copy() - - -def copy_guild_channel(channel: ChannelT) -> ChannelT: - """Logic for handling the copying of guild channel objects. - - This exists account for the permission overwrite objects attached to guild - channel objects which need to be copied themselves. - """ - channel = copy.copy(channel) - channel.permission_overwrites = { - sf: copy.copy(overwrite) for sf, overwrite in channel.permission_overwrites.items() - } - return channel - - -class Cache3DMappingView(CacheMappingView[snowflakes.Snowflake, cache.CacheView[KeyT, ValueT]]): - """A special case of the Mapping View which avoids copying the already immutable views contained within it.""" - - __slots__: typing.Sequence[str] = () - - @classmethod - def _copy(cls, value: cache.CacheView[KeyT, ValueT]) -> cache.CacheView[KeyT, ValueT]: - return value + return cell.object diff --git a/hikari/internal/routes.py b/hikari/internal/routes.py index 371ed5c128..0a0f3d33c9 100644 --- a/hikari/internal/routes.py +++ b/hikari/internal/routes.py @@ -34,7 +34,6 @@ import attr from hikari import files -from hikari.internal import attr_extensions from hikari.internal import data_binding HASH_SEPARATOR: typing.Final[str] = ";" @@ -46,10 +45,7 @@ } -# This could be frozen, except attrs' docs advise against this for performance -# reasons when using slotted classes. -@attr_extensions.with_copy -@attr.define(hash=True, weakref_slot=False) +@attr.frozen(hash=True, weakref_slot=False) @typing.final class CompiledRoute: """A compiled representation of a route to a specific resource. @@ -111,8 +107,7 @@ def __str__(self) -> str: return f"{self.method} {self.compiled_path}" -@attr_extensions.with_copy -@attr.define(hash=True, init=False, weakref_slot=False) +@attr.frozen(hash=True, init=False, weakref_slot=False) @typing.final class Route: """A template used to create compiled routes for specific parameters. @@ -138,14 +133,15 @@ class Route: """The optional major parameter name combination for this endpoint.""" def __init__(self, method: str, path_template: str) -> None: - self.method = method - self.path_template = path_template + # Since this class is "frozen" we can't use it's defined setattr (this is how attrs handles this fwiw). + object.__setattr__(self, "method", method) + object.__setattr__(self, "path_template", path_template) - self.major_params = None + object.__setattr__(self, "major_params", None) match = PARAM_REGEX.findall(path_template) for major_param_combo in MAJOR_PARAM_COMBOS.keys(): if major_param_combo.issubset(match): - self.major_params = major_param_combo + object.__setattr__(self, "major_params", major_param_combo) break def compile(self, **kwargs: typing.Any) -> CompiledRoute: @@ -181,8 +177,7 @@ def _cdn_valid_formats_converter(values: typing.AbstractSet[str]) -> typing.Froz return frozenset(v.lower() for v in values) -@attr_extensions.with_copy -@attr.define(hash=True, weakref_slot=False) +@attr.frozen(hash=True, weakref_slot=False) @typing.final class CDNRoute: """Route implementation for a CDN resource.""" diff --git a/hikari/invites.py b/hikari/invites.py index dcb0cef638..6bdc831b66 100644 --- a/hikari/invites.py +++ b/hikari/invites.py @@ -40,7 +40,6 @@ from hikari import guilds from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -82,14 +81,11 @@ def __str__(self) -> str: return f"https://discord.gg/{self.code}" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class VanityURL(InviteCode): """A special case invite object, that represents a guild's vanity url.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" code: str = attr.field(hash=True, repr=True) @@ -99,7 +95,7 @@ class VanityURL(InviteCode): """The amount of times this invite has been used.""" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class InviteGuild(guilds.PartialGuild): """Represents the partial data of a guild that is attached to invites.""" @@ -219,14 +215,11 @@ def make_banner_url(self, *, ext: str = "png", size: int = 4096) -> typing.Optio ) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Invite(InviteCode): """Represents an invite that's used to add users to a guild or group dm.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" code: str = attr.field(hash=True, repr=True) @@ -285,7 +278,7 @@ class Invite(InviteCode): """ -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class InviteWithMetadata(Invite): """Extends the base `Invite` object with metadata. diff --git a/hikari/messages.py b/hikari/messages.py index 9b8b680fd8..f7ab4908a0 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -49,7 +49,6 @@ from hikari import traits from hikari import undefined from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -192,8 +191,7 @@ class StickerFormatType(int, enums.Enum): """A lottie sticker.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Attachment(snowflakes.Unique, files.WebResource): """Represents a file attached to a message. @@ -229,8 +227,7 @@ def __str__(self) -> str: return self.filename -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Reaction: """Represents a reaction in a message.""" @@ -247,8 +244,7 @@ def __str__(self) -> str: return str(self.emoji) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Sticker(snowflakes.Unique): """Represents the stickers found attached to messages on Discord.""" @@ -278,8 +274,7 @@ class Sticker(snowflakes.Unique): """The format of this sticker's asset.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MessageActivity: """Represents the activity of a rich presence-enabled message.""" @@ -290,8 +285,7 @@ class MessageActivity: """The party ID of the message activity.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class Mentions: """Description of mentions that exist in the message.""" @@ -320,14 +314,14 @@ def channels_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowf if self.channels is undefined.UNDEFINED: return undefined.UNDEFINED - return list(self.channels.keys()) + return tuple(self.channels.keys()) @property def user_ids(self) -> undefined.UndefinedOr[typing.Sequence[snowflakes.Snowflake]]: if self.users is undefined.UNDEFINED: return undefined.UNDEFINED - return list(self.users.keys()) + return tuple(self.users.keys()) @property def members(self) -> undefined.UndefinedOr[typing.Mapping[snowflakes.Snowflake, guilds.Member]]: @@ -409,8 +403,7 @@ def _map_cache_maybe_discover( return results -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class MessageReference: """Represents information about a referenced message. @@ -418,9 +411,7 @@ class MessageReference: message, pin add messages and replies. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: typing.Optional[snowflakes.Snowflake] = attr.field(repr=True) @@ -441,8 +432,7 @@ class MessageReference: """ -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class MessageApplication(guilds.PartialApplication): """The representation of an application used in messages.""" @@ -498,8 +488,7 @@ def make_cover_image_url(self, *, ext: str = "png", size: int = 4096) -> typing. ) -@attr_extensions.with_copy -@attr.define(kw_only=True, repr=True, eq=False, weakref_slot=False) +@attr.frozen(kw_only=True, repr=True, eq=False, weakref_slot=False) class PartialMessage(snowflakes.Unique): """A message representation containing partially populated information. @@ -514,9 +503,7 @@ class PartialMessage(snowflakes.Unique): nullability. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) @@ -1187,7 +1174,7 @@ async def remove_all_reactions(self, emoji: undefined.UndefinedOr[emojis_.Emojii await self.app.rest.delete_all_reactions_for_emoji(channel=self.channel_id, message=self.id, emoji=emoji) -@attr.define(hash=True, kw_only=True, weakref_slot=False, auto_attribs=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False, auto_attribs=False) class Message(PartialMessage): """Represents a message with all known details.""" diff --git a/hikari/presences.py b/hikari/presences.py index 8e438f04f4..4bf746c18c 100644 --- a/hikari/presences.py +++ b/hikari/presences.py @@ -43,7 +43,6 @@ import attr from hikari import snowflakes -from hikari.internal import attr_extensions from hikari.internal import enums if typing.TYPE_CHECKING: @@ -89,8 +88,7 @@ class ActivityType(int, enums.Enum): """Shows up as `Competing in `.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class ActivityTimestamps: """The datetimes for the start and/or end of an activity session.""" @@ -101,8 +99,7 @@ class ActivityTimestamps: """When this activity's session will end, if applicable.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class ActivityParty: """Used to represent activity groups of users.""" @@ -116,8 +113,7 @@ class ActivityParty: """Maximum size of this party, if applicable.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class ActivityAssets: """Used to represent possible assets for an activity.""" @@ -134,8 +130,7 @@ class ActivityAssets: """The text that'll appear when hovering over the small image, if set.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class ActivitySecret: """The secrets used for interacting with an activity party.""" @@ -176,8 +171,7 @@ class ActivityFlag(enums.Flag): # TODO: add strict type checking to gateway for this type in an invariant way. -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class Activity: """Represents a regular activity that can be associated with a presence.""" @@ -194,7 +188,7 @@ def __str__(self) -> str: return self.name -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class RichActivity(Activity): """Represents a rich activity that can be associated with a presence.""" @@ -254,8 +248,7 @@ class Status(str, enums.Enum): """Offline or invisible/grey.""" -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class ClientStatus: """The client statuses for this member.""" @@ -269,14 +262,11 @@ class ClientStatus: """The status of the target user's web session.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class MemberPresence: """Used to represent a guild member's presence.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" user_id: snowflakes.Snowflake = attr.field(repr=True, hash=True) diff --git a/hikari/sessions.py b/hikari/sessions.py index 5dd9c3b7d0..1e1f95d21b 100644 --- a/hikari/sessions.py +++ b/hikari/sessions.py @@ -30,15 +30,13 @@ import attr -from hikari.internal import attr_extensions from hikari.internal import time if typing.TYPE_CHECKING: import datetime -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class SessionStartLimit: """Used to represent information about the current session start limits.""" @@ -78,8 +76,7 @@ def reset_at(self) -> datetime.datetime: return self._created_at + self.reset_after -@attr_extensions.with_copy -@attr.define(hash=False, kw_only=True, weakref_slot=False) +@attr.frozen(hash=False, kw_only=True, weakref_slot=False) class GatewayBot: """Used to represent gateway information for the connected bot.""" diff --git a/hikari/snowflakes.py b/hikari/snowflakes.py index 19d90c1bab..e26bc928db 100644 --- a/hikari/snowflakes.py +++ b/hikari/snowflakes.py @@ -124,11 +124,6 @@ def id(self) -> Snowflake: The snowflake ID of this object. """ - # TODO: make immutable interface, as this is a major risk to consistent hash codes. - @id.setter - def id(self, value: Snowflake) -> None: - """Set the ID on this entity.""" - @property def created_at(self) -> datetime.datetime: """When the object was created.""" diff --git a/hikari/templates.py b/hikari/templates.py index 83748e6247..f25a2b28c5 100644 --- a/hikari/templates.py +++ b/hikari/templates.py @@ -31,7 +31,6 @@ import attr from hikari import guilds -from hikari.internal import attr_extensions if typing.TYPE_CHECKING: import datetime @@ -43,8 +42,7 @@ from hikari import users -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class TemplateRole(guilds.PartialRole): """The partial role object attached to `Template`.""" @@ -70,8 +68,7 @@ class TemplateRole(guilds.PartialRole): """Whether this role can be mentioned by all regardless of permissions.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class TemplateGuild(guilds.PartialGuild): """The partial guild object attached to `Template`.""" @@ -142,8 +139,7 @@ class TemplateGuild(guilds.PartialGuild): """ -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Template: """Represents a template used for creating guilds.""" diff --git a/hikari/users.py b/hikari/users.py index a51bcd0416..86c63a967f 100644 --- a/hikari/users.py +++ b/hikari/users.py @@ -34,7 +34,6 @@ from hikari import snowflakes from hikari import undefined from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -539,8 +538,7 @@ def make_avatar_url(self, *, ext: typing.Optional[str] = None, size: int = 4096) ) -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class PartialUserImpl(PartialUser): """Implementation for partial information about a user. @@ -551,9 +549,7 @@ class PartialUserImpl(PartialUser): id: snowflakes.Snowflake = attr.field(hash=True, repr=True) """The ID of this user.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """Reference to the client application that models may use for procedures.""" discriminator: undefined.UndefinedOr[str] = attr.field(eq=False, hash=False, repr=True) @@ -609,7 +605,7 @@ def __str__(self) -> str: return f"{self.username}#{self.discriminator}" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class UserImpl(PartialUserImpl, User): """Concrete implementation of user information.""" @@ -636,7 +632,7 @@ class UserImpl(PartialUserImpl, User): """The public flags for this user.""" -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class OwnUser(UserImpl): """Represents a user with extended OAuth2 information.""" diff --git a/hikari/voices.py b/hikari/voices.py index 9f5f19b69f..8ec100d48e 100644 --- a/hikari/voices.py +++ b/hikari/voices.py @@ -30,8 +30,6 @@ import attr -from hikari.internal import attr_extensions - if typing.TYPE_CHECKING: import datetime @@ -40,14 +38,11 @@ from hikari import traits -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class VoiceState: """Represents a user's voice connection status.""" - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" channel_id: typing.Optional[snowflakes.Snowflake] = attr.field(eq=False, hash=False, repr=True) @@ -100,8 +95,7 @@ class VoiceState: """ -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class VoiceRegion: """Represents a voice region server.""" diff --git a/hikari/webhooks.py b/hikari/webhooks.py index f9c2eb6a78..df16757a3b 100644 --- a/hikari/webhooks.py +++ b/hikari/webhooks.py @@ -33,7 +33,6 @@ from hikari import snowflakes from hikari import undefined from hikari import urls -from hikari.internal import attr_extensions from hikari.internal import enums from hikari.internal import routes @@ -59,8 +58,7 @@ class WebhookType(int, enums.Enum): """Channel Follower webhook.""" -@attr_extensions.with_copy -@attr.define(hash=True, kw_only=True, weakref_slot=False) +@attr.frozen(hash=True, kw_only=True, weakref_slot=False) class Webhook(snowflakes.Unique): """Represents a webhook object on Discord. @@ -69,9 +67,7 @@ class Webhook(snowflakes.Unique): send informational messages to specific channels. """ - app: traits.RESTAware = attr.field( - repr=False, eq=False, hash=False, metadata={attr_extensions.SKIP_DEEP_COPY: True} - ) + app: traits.RESTAware = attr.field(repr=False, eq=False, hash=False) """The client application that models may use for procedures.""" id: snowflakes.Snowflake = attr.field(hash=True, repr=True) diff --git a/tests/hikari/events/test_message_events.py b/tests/hikari/events/test_message_events.py index 0a0a8f6cf8..ae6b7802b5 100644 --- a/tests/hikari/events/test_message_events.py +++ b/tests/hikari/events/test_message_events.py @@ -182,8 +182,8 @@ def event(self): def test_guild_id_property(self, event): assert event.guild_id == snowflakes.Snowflake(342123123) - def test_channel_property_when_no_cache_trait(self, event): - event.app = object() + def test_channel_property_when_no_cache_trait(self): + event = message_events.GuildMessageCreateEvent(app=None, message=None, shard=None) assert event.channel is None @@ -195,8 +195,8 @@ def test_channel_property(self, event, guild_channel_impl): assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(9121234) - def test_guild_property_when_no_cache_trait(self, event): - event.app = object() + def test_guild_property_when_no_cache_trait(self): + event = message_events.GuildMessageCreateEvent(app=None, message=None, shard=None) assert event.guild is None @@ -259,8 +259,8 @@ def test_author_property_when_member_none_and_uncached_but_author_defined(self, def test_guild_id_property(self, event): assert event.guild_id == snowflakes.Snowflake(54123123123) - def test_channel_property_when_no_cache_trait(self, event): - event.app = object() + def test_channel_property_when_no_cache_trait(self): + event = message_events.GuildMessageUpdateEvent(app=None, message=None, old_message=None, shard=None) assert event.channel is None @@ -272,8 +272,8 @@ def test_channel_property(self, event, guild_channel_impl): assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(800001066) - def test_guild_property_when_no_cache_trait(self, event): - event.app = object() + def test_guild_property_when_no_cache_trait(self): + event = message_events.GuildMessageUpdateEvent(app=None, message=None, old_message=None, shard=None) assert event.guild is None @@ -326,8 +326,10 @@ def event(self): is_bulk=True, ) - def test_channel_property_when_no_cache_trait(self, event): - event.app = object() + def test_channel_property_when_no_cache_trait(self): + event = message_events.GuildMessageDeleteEvent( + guild_id=None, channel_id=None, app=None, shard=None, message_ids=None, is_bulk=None + ) assert event.channel is None @@ -339,8 +341,10 @@ def test_channel_property(self, event, guild_channel_impl): assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(54213123123) - def test_guild_property_when_no_cache_trait(self, event): - event.app = object() + def test_guild_property_when_no_cache_trait(self): + event = message_events.GuildMessageDeleteEvent( + guild_id=None, channel_id=None, app=None, shard=None, message_ids=None, is_bulk=None + ) assert event.guild is None diff --git a/tests/hikari/events/test_shard_events.py b/tests/hikari/events/test_shard_events.py index 769c66a5be..c5d49e4569 100644 --- a/tests/hikari/events/test_shard_events.py +++ b/tests/hikari/events/test_shard_events.py @@ -38,6 +38,7 @@ def event(self): snowflakes.Snowflake(55): mock.Mock(), snowflakes.Snowflake(99): mock.Mock(), snowflakes.Snowflake(455): mock.Mock(), + snowflakes.Snowflake(55555): mock.Mock(), }, chunk_count=1, chunk_index=1, @@ -47,11 +48,7 @@ def event(self): ) def test___getitem___with_slice(self, event): - mock_member_0 = object() - mock_member_1 = object() - event.members = {1: object(), 55: object(), 99: mock_member_0, 455: object(), 5444: mock_member_1} - - assert event[2:5:2] == (mock_member_0, mock_member_1) + assert event[2:5:2] == (event.members[99], event.members[55555]) def test___getitem___with_valid_index(self, event): mock_member = object() @@ -66,17 +63,13 @@ def test___getitem___with_invalid_index(self, event): assert event[123] def test___iter___(self, event): - member_0 = mock.Mock() - member_1 = mock.Mock() - member_2 = mock.Mock() - - event.members = { - snowflakes.Snowflake(1): member_0, - snowflakes.Snowflake(2): member_1, - snowflakes.Snowflake(3): member_2, - } - - assert list(event) == [member_0, member_1, member_2] + assert list(event) == [ + event.members[1], + event.members[55], + event.members[99], + event.members[455], + event.members[55555], + ] def test___len___(self, event): - assert len(event) == 4 + assert len(event) == 5 diff --git a/tests/hikari/events/test_typing_events.py b/tests/hikari/events/test_typing_events.py index 6f557d061e..55cb6b2bee 100644 --- a/tests/hikari/events/test_typing_events.py +++ b/tests/hikari/events/test_typing_events.py @@ -65,8 +65,10 @@ def event(self): user=mock.Mock(id=456), ) - def test_channel_when_no_cache(self, event): - event.app = object() + def test_channel_when_no_cache(self): + event = typing_events.GuildTypingEvent( + channel_id=None, timestamp=None, shard=None, app=None, guild_id=None, user=None + ) assert event.channel is None @@ -78,8 +80,10 @@ def test_channel(self, event, guild_channel_impl): assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(123) - async def test_guild_when_no_cache(self, event): - event.app = object() + async def test_guild_when_no_cache(self): + event = typing_events.GuildTypingEvent( + channel_id=None, timestamp=None, shard=None, app=None, guild_id=None, user=None + ) assert event.guild is None @@ -139,8 +143,8 @@ def event(self): user_id=456, ) - async def test_user_when_no_cache(self, event): - event.app = object() + async def test_user_when_no_cache(self): + event = typing_events.DMTypingEvent(app=None, shard=None, user_id=None, timestamp=None, channel_id=None) assert event.user is None diff --git a/tests/hikari/events/test_voice_events.py b/tests/hikari/events/test_voice_events.py index f8eb3e915f..5afd3f7b5b 100644 --- a/tests/hikari/events/test_voice_events.py +++ b/tests/hikari/events/test_voice_events.py @@ -53,6 +53,6 @@ def event(self): def test_endpoint_property(self, event): assert event.endpoint == "wss://voice.discord.com:123" - def test_endpoint_property_when_raw_endpoint_is_None(self, event): - event.raw_endpoint = None + def test_endpoint_property_when_raw_endpoint_is_None(self): + event = voice_events.VoiceServerUpdateEvent(app=None, shard=None, guild_id=None, token=None, raw_endpoint=None) assert event.endpoint is None diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index 02809ab696..df4c0b8e14 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -84,7 +84,7 @@ def test_clear(self, cache_impl): cache_impl._create_cache.assert_called_once_with() def test__build_emoji(self, cache_impl): - mock_user = mock.MagicMock(users.User) + mock_user = object() emoji_data = cache_utilities.KnownCustomEmojiData( id=snowflakes.Snowflake(1233534234), name="OKOKOKOKOK", @@ -103,8 +103,7 @@ def test__build_emoji(self, cache_impl): assert emoji.id == snowflakes.Snowflake(1233534234) assert emoji.name == "OKOKOKOKOK" assert emoji.guild_id == snowflakes.Snowflake(65234123) - assert emoji.user == mock_user - assert emoji.user is not mock_user + assert emoji.user is mock_user assert emoji.is_animated is True assert emoji.is_colons_required is False assert emoji.is_managed is False @@ -598,7 +597,7 @@ def test_delete_guild_for_unknown_record(self, cache_impl): assert cache_impl._guild_entries == {snowflakes.Snowflake(354123): cache_utilities.GuildRecord()} def test_get_guild_first_tries_get_available_guilds(self, cache_impl): - mock_guild = mock.MagicMock(guilds.GatewayGuild) + mock_guild = object() cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), @@ -608,11 +607,10 @@ def test_get_guild_first_tries_get_available_guilds(self, cache_impl): cached_guild = cache_impl.get_guild(StubModel(543123)) - assert cached_guild == mock_guild - assert cache_impl is not mock_guild + assert cached_guild is mock_guild def test_get_guild_then_tries_get_unavailable_guilds(self, cache_impl): - mock_guild = mock.MagicMock(guilds.GatewayGuild) + mock_guild = object() cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(543123): cache_utilities.GuildRecord(is_available=True), @@ -622,11 +620,10 @@ def test_get_guild_then_tries_get_unavailable_guilds(self, cache_impl): cached_guild = cache_impl.get_guild(StubModel(54234123)) - assert cached_guild == mock_guild - assert cache_impl is not mock_guild + assert cached_guild is mock_guild def test_get_available_guild_for_known_guild_when_available(self, cache_impl): - mock_guild = mock.MagicMock(guilds.GatewayGuild) + mock_guild = object() cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), @@ -636,8 +633,7 @@ def test_get_available_guild_for_known_guild_when_available(self, cache_impl): cached_guild = cache_impl.get_available_guild(StubModel(543123)) - assert cached_guild == mock_guild - assert cache_impl is not mock_guild + assert cached_guild is mock_guild def test_get_available_guild_for_known_guild_when_unavailable(self, cache_impl): mock_guild = mock.Mock(guilds.GatewayGuild) @@ -674,7 +670,7 @@ def test_get_available_guild_for_unknown_guild_record(self, cache_impl): assert result is None def test_get_unavailable_guild_for_known_guild_when_unavailable(self, cache_impl): - mock_guild = mock.MagicMock(guilds.GatewayGuild) + mock_guild = object() cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), @@ -684,8 +680,7 @@ def test_get_unavailable_guild_for_known_guild_when_unavailable(self, cache_impl cached_guild = cache_impl.get_unavailable_guild(StubModel(452131)) - assert cached_guild == mock_guild - assert cache_impl is not mock_guild + assert cached_guild is mock_guild def test_get_unavailable_guild_for_known_guild_when_available(self, cache_impl): mock_guild = mock.Mock(guilds.GatewayGuild) @@ -807,13 +802,12 @@ def test_get_unavailable_guilds_view_when_no_guilds_cached(self, cache_impl): assert result == {} def test_set_guild(self, cache_impl): - mock_guild = mock.MagicMock(guilds.GatewayGuild, id=snowflakes.Snowflake(5123123)) + mock_guild = mock.Mock(guilds.GatewayGuild, id=snowflakes.Snowflake(5123123)) cache_impl.set_guild(mock_guild) assert 5123123 in cache_impl._guild_entries - assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].guild == mock_guild - assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].guild is not mock_guild + assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].guild is mock_guild assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].is_available is True def test_set_guild_availability_for_cached_guild(self, cache_impl): @@ -865,8 +859,8 @@ def test_update_guild_channel(self, cache_impl): ... def test__build_invite(self, cache_impl): - mock_inviter = mock.MagicMock(users.User) - mock_target_user = mock.MagicMock(users.User) + mock_inviter = object() + mock_target_user = object() invite_data = cache_utilities.InviteData( code="okokok", guild_id=snowflakes.Snowflake(965234), @@ -889,10 +883,8 @@ def test__build_invite(self, cache_impl): assert invite.guild_id == snowflakes.Snowflake(965234) assert invite.channel is None assert invite.channel_id == snowflakes.Snowflake(87345234) - assert invite.inviter == mock_inviter - assert invite.target_user == mock_target_user - assert invite.inviter is not mock_inviter - assert invite.target_user is not mock_target_user + assert invite.inviter is mock_inviter + assert invite.target_user is mock_target_user assert invite.target_user_type is invites.TargetUserType.STREAM assert invite.approximate_active_member_count is None assert invite.approximate_member_count is None @@ -1401,24 +1393,22 @@ def test_delete_me_for_unknown_me(self, cache_impl): assert cache_impl._me is None def test_get_me_for_known_me(self, cache_impl): - mock_own_user = mock.MagicMock(users.OwnUser) + mock_own_user = object() cache_impl._me = mock_own_user cached_me = cache_impl.get_me() - assert cached_me == mock_own_user - assert cached_me is not mock_own_user + assert cached_me is mock_own_user def test_get_me_for_unknown_me(self, cache_impl): assert cache_impl.get_me() is None def test_set_me(self, cache_impl): - mock_own_user = mock.MagicMock(users.OwnUser) + mock_own_user = object() cache_impl.set_me(mock_own_user) - assert cache_impl._me == mock_own_user - assert cache_impl._me is not mock_own_user + assert cache_impl._me is mock_own_user def test_update_me_for_cached_me(self, cache_impl): mock_cached_own_user = mock.MagicMock(users.OwnUser) @@ -1439,7 +1429,7 @@ def test_update_me_for_uncached_me(self, cache_impl): assert cache_impl._me == mock_own_user def test__build_member(self, cache_impl): - mock_user = mock.MagicMock(users.User) + mock_user = object() member_data = cache_utilities.MemberData( user=cache_utilities.RefCell(mock_user), guild_id=snowflakes.Snowflake(6434435234), @@ -1454,8 +1444,7 @@ def test__build_member(self, cache_impl): member = cache_impl._build_member(cache_utilities.RefCell(member_data)) - assert member.user == mock_user - assert member.user is not mock_user + assert member.user is mock_user assert member.guild_id == 6434435234 assert member.nickname == "NICK" assert member.role_ids == (snowflakes.Snowflake(65234), snowflakes.Snowflake(654234123)) @@ -1813,8 +1802,6 @@ def test_set_member(self, cache_impl): assert member_entry.object.guild_id == 67345234 assert member_entry.object.nickname == "A NICK LOL" assert member_entry.object.role_ids == (65345234, 123123) - assert member_entry.object.role_ids is not member_model.role_ids - assert isinstance(member_entry.object.role_ids, tuple) assert member_entry.object.joined_at == datetime.datetime( 2020, 7, 15, 23, 30, 59, 501602, tzinfo=datetime.timezone.utc ) @@ -1998,7 +1985,7 @@ def test_get_users_view_for_empty_user_cache(self, cache_impl): assert cache_impl.get_users_view() == {} def test__set_user(self, cache_impl): - mock_user = mock.MagicMock(users.User, id=snowflakes.Snowflake(6451234123)) + mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(6451234123)) cache_impl._user_entries = collections.FreezableDict( {snowflakes.Snowflake(542143): mock.Mock(cache_utilities.RefCell)} ) @@ -2007,11 +1994,10 @@ def test__set_user(self, cache_impl): assert result is cache_impl._user_entries[snowflakes.Snowflake(6451234123)] assert 6451234123 in cache_impl._user_entries - assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object == mock_user - assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is not mock_user + assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is mock_user def test__set_user_carries_over_ref_count(self, cache_impl): - mock_user = mock.MagicMock(users.User, id=snowflakes.Snowflake(6451234123)) + mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(6451234123)) cache_impl._user_entries = collections.FreezableDict( { snowflakes.Snowflake(542143): mock.Mock(cache_utilities.RefCell), @@ -2023,8 +2009,7 @@ def test__set_user_carries_over_ref_count(self, cache_impl): assert result is cache_impl._user_entries[snowflakes.Snowflake(6451234123)] assert 6451234123 in cache_impl._user_entries - assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object == mock_user - assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is not mock_user + assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is mock_user assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].ref_count == 42 def test__build_voice_state(self, cache_impl): @@ -2330,7 +2315,7 @@ def test_update_voice_state(self, cache_impl): ) def test__build_message(self, cache_impl): - mock_author = mock.MagicMock(users.User) + mock_author = object() mock_member = object() member_data = mock.Mock(build_entity=mock.Mock(return_value=mock_member)) mock_channel = mock.MagicMock() @@ -2341,14 +2326,14 @@ def test__build_message(self, cache_impl): channels={snowflakes.Snowflake(4444): mock_channel}, everyone=True, ) - mock_attachment = mock.MagicMock(messages.Attachment) - mock_embed_field = mock.MagicMock(embeds.EmbedField) - mock_embed = mock.MagicMock(embeds.Embed, fields=(mock_embed_field,)) - mock_sticker = mock.MagicMock(messages.Sticker) - mock_reaction = mock.MagicMock(messages.Reaction) - mock_activity = mock.MagicMock(messages.MessageActivity) - mock_applcation = mock.MagicMock(messages.MessageApplication) - mock_reference = mock.MagicMock(messages.MessageReference) + mock_attachment = object() + mock_embed_field = object() + mock_embed = mock.Mock(embeds.Embed, fields=(mock_embed_field,)) + mock_sticker = object() + mock_reaction = object() + mock_activity = object() + mock_applcation = object() + mock_reference = object() mock_referenced_message = object() mock_referenced_message_data = mock.Mock( cache_utilities.MessageData, build_entity=mock.Mock(return_value=mock_referenced_message) @@ -2385,8 +2370,7 @@ def test__build_message(self, cache_impl): assert result.id == 32123123 assert result.channel_id == 3123123123 assert result.guild_id == 5555555 - assert result.author == mock_author - assert result.author is not mock_author + assert result.author is mock_author assert result.member is mock_member assert result.content == "OKOKOK" assert result.timestamp == datetime.datetime(2020, 7, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc) @@ -2423,12 +2407,9 @@ def test__build_message(self, cache_impl): assert result.is_pinned is False assert result.webhook_id == 3123123 assert result.type is messages.MessageType.REPLY - assert result.activity == mock_activity - assert result.activity is not mock_activity - assert result.application == mock_applcation - assert result.application is not mock_applcation - assert result.message_reference == mock_reference - assert result.message_reference is not mock_reference + assert result.activity is mock_activity + assert result.application is mock_applcation + assert result.message_reference is mock_reference assert result.flags == messages.MessageFlag.CROSSPOSTED assert result.stickers == (mock_sticker,) assert result.nonce == "aNonce" diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index df2bcc494a..1c64e3c072 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -125,7 +125,9 @@ def test_deserialize_own_connection(self, entity_factory_impl, own_connection_pa assert own_connection.name == "FS" assert own_connection.type == "twitter" assert own_connection.is_revoked is False - assert own_connection.integrations == [entity_factory_impl.deserialize_partial_integration(partial_integration)] + assert own_connection.integrations == ( + entity_factory_impl.deserialize_partial_integration(partial_integration), + ) assert own_connection.is_verified is True assert own_connection.is_friend_sync_enabled is False assert own_connection.is_activity_visible is True @@ -139,7 +141,7 @@ def test_deserialize_own_connection_when_integrations_is_None(self, entity_facto assert own_connection.name == "FS" assert own_connection.type == "twitter" assert own_connection.is_revoked is False - assert own_connection.integrations == [] + assert own_connection.integrations == () assert own_connection.is_verified is True assert own_connection.is_friend_sync_enabled is False assert own_connection.is_activity_visible is True @@ -162,7 +164,7 @@ def test_deserialize_own_guild(self, entity_factory_impl, mock_app, own_guild_pa assert own_guild.id == 152559372126519269 assert own_guild.name == "Isopropyl" assert own_guild.icon_hash == "d4a983885dsaa7691ce8bcaaf945a" - assert own_guild.features == [guild_models.GuildFeature.DISCOVERABLE, "FORCE_RELAY"] + assert own_guild.features == (guild_models.GuildFeature.DISCOVERABLE, "FORCE_RELAY") assert own_guild.is_owner is False assert own_guild.my_permissions == permission_models.Permissions(2147483647) @@ -224,7 +226,7 @@ def test_deserialize_application( assert application.is_bot_public is True assert application.is_bot_code_grant_required is False assert application.owner == entity_factory_impl.deserialize_user(owner_payload) - assert application.rpc_origins == ["127.0.0.0"] + assert application.rpc_origins == ("127.0.0.0",) assert application.summary == "not a blank string" assert ( application.public_key @@ -243,7 +245,7 @@ def test_deserialize_application( assert len(application.team.members) == 1 member = application.team.members[115590097100865541] assert member.membership_state == application_models.TeamMembershipState.INVITED - assert member.permissions == ["*"] + assert member.permissions == ("*",) assert member.team_id == 209333111222 assert member.user == entity_factory_impl.deserialize_user(user_payload) assert isinstance(member, application_models.TeamMember) @@ -336,7 +338,7 @@ def test_deserialize_authorization_information( assert authorization_information.expires_at == datetime.datetime( 2021, 2, 1, 18, 3, 20, 888000, tzinfo=datetime.timezone.utc ) - assert authorization_information.scopes == ["identify", "guilds", "applications.commands.update"] + assert authorization_information.scopes == ("identify", "guilds", "applications.commands.update") assert authorization_information.user == entity_factory_impl.deserialize_user(user_payload) def test_deserialize_authorization_information_with_unset_fields( @@ -374,10 +376,10 @@ def test_deserialize_partial_token(self, entity_factory_impl, client_credentials assert partial_token.access_token == "6qrZcUqja7812RVdnEKjpzOL4CvHBFG" assert partial_token.token_type is application_models.TokenType.BEARER assert partial_token.expires_in == datetime.timedelta(weeks=1) - assert partial_token.scopes == [ + assert partial_token.scopes == ( application_models.OAuth2Scope.IDENTIFY, application_models.OAuth2Scope.CONNECTIONS, - ] + ) assert isinstance(partial_token, application_models.PartialOAuth2Token) @pytest.fixture() @@ -400,10 +402,10 @@ def test_deserialize_authorization_token( assert access_token.token_type is application_models.TokenType.BEARER assert access_token.guild == entity_factory_impl.deserialize_rest_guild(deserialize_rest_guild_payload) assert access_token.access_token == "zMndOe7jFLXGawdlxMOdNvXjjOce5X" - assert access_token.scopes == [ + assert access_token.scopes == ( application_models.OAuth2Scope.BOT, application_models.OAuth2Scope.WEBHOOK_INCOMING, - ] + ) assert access_token.expires_in == datetime.timedelta(weeks=4) assert access_token.refresh_token == "mgp8qnvBwJcmadwgCYKyYD5CAzGAX4" assert access_token.webhook == entity_factory_impl.deserialize_webhook(webhook_payload) @@ -433,7 +435,7 @@ def test_deserialize_implicit_token(self, entity_factory_impl, implicit_token_pa assert implicit_token.access_token == "RTfP0OK99U3kbRtHOoKLmJbOn45PjL" assert implicit_token.token_type is application_models.TokenType.BASIC assert implicit_token.expires_in == datetime.timedelta(weeks=2) - assert implicit_token.scopes == [application_models.OAuth2Scope.IDENTIFY] + assert implicit_token.scopes == (application_models.OAuth2Scope.IDENTIFY,) assert implicit_token.state == "15773059ghq9183habn" assert isinstance(implicit_token, application_models.OAuth2ImplicitToken) @@ -664,7 +666,7 @@ def test_deserialize_audit_log_with_unset_or_unknown_fields(self, entity_factory assert len(audit_log.entries) == 1 entry = audit_log.entries[694026906592477214] - assert entry.changes == [] + assert entry.changes == () assert entry.target_id is None assert entry.user_id is None assert entry.action_type == 69 @@ -682,8 +684,12 @@ def test_deserialize_audit_log_with_unhandled_change_key(self, entity_factory_im assert len(entry.changes) == 1 change = entry.changes[0] assert change.key == audit_log_models.AuditLogChangeKey.NAME - assert change.new_value == [{"id": "568651298858074123", "name": "Casual"}] - assert change.old_value == [{"id": "123123123312312", "name": "aRole"}] + assert change.new_value == [ + {"id": "568651298858074123", "name": "Casual"}, + ] + assert change.old_value == [ + {"id": "123123123312312", "name": "aRole"}, + ] def test_deserialize_audit_log_with_change_key_unknown(self, entity_factory_impl, audit_log_payload): # Unset fields @@ -1671,7 +1677,7 @@ def test_deserialize_known_custom_emoji( assert emoji.guild_id == 1235123 assert emoji.name == "testing" assert emoji.is_animated is False - assert emoji.role_ids == [123, 456] + assert emoji.role_ids == (123, 456) assert emoji.user == entity_factory_impl.deserialize_user(user_payload) assert emoji.is_colons_required is True assert emoji.is_managed is False @@ -1826,7 +1832,7 @@ def test_deserialize_member(self, entity_factory_impl, mock_app, member_payload, assert member.guild_id == 76543325 assert member.user == entity_factory_impl.deserialize_user(user_payload) assert member.nickname == "foobarbaz" - assert member.role_ids == [11111, 22222, 33333, 44444, 76543325] + assert member.role_ids == (11111, 22222, 33333, 44444, 76543325) assert member.joined_at == datetime.datetime(2015, 4, 26, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) assert member.premium_since == datetime.datetime(2019, 5, 17, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) assert member.is_deaf is False @@ -1846,7 +1852,7 @@ def test_deserialize_member_when_guild_id_already_in_role_array( assert member.guild_id == 76543325 assert member.user == entity_factory_impl.deserialize_user(user_payload) assert member.nickname == "foobarbaz" - assert member.role_ids == [11111, 22222, 76543325, 33333, 44444] + assert member.role_ids == (11111, 22222, 76543325, 33333, 44444) assert member.joined_at == datetime.datetime(2015, 4, 26, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) assert member.premium_since == datetime.datetime(2019, 5, 17, 6, 26, 56, 936000, tzinfo=datetime.timezone.utc) assert member.is_deaf is False @@ -2123,7 +2129,7 @@ def test_deserialize_guild_preview( assert guild_preview.id == 152559372126519269 assert guild_preview.name == "Isopropyl" assert guild_preview.icon_hash == "d4a983885dsaa7691ce8bcaaf945a" - assert guild_preview.features == [guild_models.GuildFeature.DISCOVERABLE, "FORCE_RELAY"] + assert guild_preview.features == (guild_models.GuildFeature.DISCOVERABLE, "FORCE_RELAY") assert guild_preview.splash_hash == "dsa345tfcdg54b" assert guild_preview.discovery_splash_hash == "lkodwaidi09239uid" assert guild_preview.emojis == { @@ -2214,12 +2220,12 @@ def test_deserialize_rest_guild( assert guild.id == 265828729970753537 assert guild.name == "L33t guild" assert guild.icon_hash == "1a2b3c4d" - assert guild.features == [ + assert guild.features == ( guild_models.GuildFeature.ANIMATED_ICON, guild_models.GuildFeature.MORE_EMOJI, guild_models.GuildFeature.NEWS, "SOME_UNDOCUMENTED_FEATURE", - ] + ) assert guild.splash_hash == "0ff0ff0ff" assert guild.discovery_splash_hash == "famfamFAMFAMfam" assert guild.owner_id == 6969696 @@ -2433,12 +2439,12 @@ def test_deserialize_gateway_guild( assert guild.id == 265828729970753537 assert guild.name == "L33t guild" assert guild.icon_hash == "1a2b3c4d" - assert guild.features == [ + assert guild.features == ( guild_models.GuildFeature.ANIMATED_ICON, guild_models.GuildFeature.MORE_EMOJI, guild_models.GuildFeature.NEWS, "SOME_UNDOCUMENTED_FEATURE", - ] + ) assert guild.splash_hash == "0ff0ff0ff" assert guild.discovery_splash_hash == "famfamFAMFAMfam" assert guild.owner_id == 6969696 @@ -2688,7 +2694,7 @@ def test_deserialize_invite( assert invite.guild.id == 56188492224814744 assert invite.guild.name == "Testin' Your Scene" assert invite.guild.icon_hash == "bb71f469c158984e265093a81b3397fb" - assert invite.guild.features == ["FORCE_RELAY"] + assert invite.guild.features == ("FORCE_RELAY",) assert invite.guild.splash_hash == "aSplashForSure" assert invite.guild.banner_hash == "aBannerForSure" assert invite.guild.description == "Describe me cute kitty." @@ -2804,7 +2810,7 @@ def test_deserialize_invite_with_metadata( assert invite_with_metadata.guild.id == 56188492224814744 assert invite_with_metadata.guild.name == "Testin' Your Scene" assert invite_with_metadata.guild.icon_hash == "bb71f469c158984e265093a81b3397fb" - assert invite_with_metadata.guild.features == ["FORCE_RELAY"] + assert invite_with_metadata.guild.features == ("FORCE_RELAY",) assert invite_with_metadata.guild.splash_hash == "aSplashForSure" assert invite_with_metadata.guild.banner_hash == "aBannerForSure" assert invite_with_metadata.guild.description == "Describe me cute kitty." @@ -2979,9 +2985,9 @@ def test_deserialize_partial_message( ) assert partial_message.is_tts is True assert partial_message.mentions.everyone is True - assert partial_message.mentions.user_ids == [5678] - assert partial_message.mentions.role_ids == [987] - assert partial_message.mentions.channels_ids == [456] + assert partial_message.mentions.user_ids == (5678,) + assert partial_message.mentions.role_ids == (987,) + assert partial_message.mentions.channels_ids == (456,) # Attachment assert len(partial_message.attachments) == 1 attachment = partial_message.attachments[0] @@ -2996,7 +3002,7 @@ def test_deserialize_partial_message( assert isinstance(attachment, message_models.Attachment) expected_embed = entity_factory_impl.deserialize_embed(embed_payload) - assert partial_message.embeds == [expected_embed] + assert partial_message.embeds == (expected_embed,) # Reaction reaction = partial_message.reactions[0] assert reaction.count == 100 @@ -3233,18 +3239,18 @@ def test_deserialize_message_with_null_and_unset_fields( assert message.member is None assert message.edited_timestamp is None assert message.mentions.everyone is True - assert message.mentions.user_ids == [] - assert message.mentions.role_ids == [] - assert message.mentions.channels_ids == [] - assert message.attachments == [] - assert message.embeds == [] - assert message.reactions == [] + assert message.mentions.user_ids == () + assert message.mentions.role_ids == () + assert message.mentions.channels_ids == () + assert message.attachments == () + assert message.embeds == () + assert message.reactions == () assert message.webhook_id is None assert message.activity is None assert message.application is None assert message.message_reference is None assert message.referenced_message is undefined.UNDEFINED - assert message.stickers == [] + assert message.stickers == () assert message.nonce is None def test_deserialize_message_with_other_unset_fields(self, entity_factory_impl, message_payload): @@ -3355,7 +3361,7 @@ def test_deserialize_member_presence( assert presence.client_status.web == presence_models.Status.DO_NOT_DISTURB assert isinstance(presence.client_status, presence_models.ClientStatus) - assert activity.buttons == ["owo", "no"] + assert activity.buttons == ("owo", "no") assert isinstance(presence, presence_models.MemberPresence) def test_deserialize_member_presence_with_unset_fields( @@ -3407,7 +3413,7 @@ def test_deserialize_member_presence_with_unset_activity_fields(self, entity_fac assert activity.secrets is None assert activity.is_instance is None assert activity.flags is None - assert activity.buttons == [] + assert activity.buttons == () def test_deserialize_member_presence_with_null_activity_fields(self, entity_factory_impl, user_payload): presence = entity_factory_impl.deserialize_member_presence( diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index 9eee34915c..e493d6e051 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -331,7 +331,7 @@ def test_deserialize_guild_emojis_update_event(self, event_factory, mock_app, mo assert isinstance(event, guild_events.EmojisUpdateEvent) assert event.app is mock_app assert event.shard is mock_shard - assert event.emojis == [mock_app.entity_factory.deserialize_known_custom_emoji.return_value] + assert event.emojis == (mock_app.entity_factory.deserialize_known_custom_emoji.return_value,) assert event.guild_id == 123431 assert event.old_emojis is mock_old_emojis @@ -897,7 +897,7 @@ def test_deserialize_guild_member_chunk_event_with_optional_fields(self, event_f assert event.members == {4222222: mock_app.entity_factory.deserialize_member.return_value} assert event.chunk_count == 54 assert event.chunk_index == 3 - assert event.not_found == [34212312312, 323123123] + assert event.not_found == (34212312312, 323123123) assert event.presences == {43123123: mock_app.entity_factory.deserialize_member_presence.return_value} assert event.nonce == "OKOKOKOK" @@ -912,7 +912,7 @@ def test_deserialize_guild_member_chunk_event_without_optional_fields(self, even event = event_factory.deserialize_guild_member_chunk_event(mock_shard, mock_payload) - assert event.not_found == [] + assert event.not_found == () assert event.presences == {} assert event.nonce is None diff --git a/tests/hikari/internal/test_attr_extensions.py b/tests/hikari/internal/test_attr_extensions.py deleted file mode 100644 index b906baa1d6..0000000000 --- a/tests/hikari/internal/test_attr_extensions.py +++ /dev/null @@ -1,446 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2020 Nekokatt -# Copyright (c) 2021 davfsa -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -import contextlib -import copy as stdlib_copy - -import attr -import mock - -from hikari.internal import attr_extensions - - -def test_invalidate_shallow_copy_cache(): - attr_extensions._SHALLOW_COPIERS = {int: object(), str: object()} - assert attr_extensions.invalidate_shallow_copy_cache() is None - assert attr_extensions._SHALLOW_COPIERS == {} - - -def test_invalidate_deep_copy_cache(): - attr_extensions._DEEP_COPIERS = {str: object(), int: object(), object: object()} - assert attr_extensions.invalidate_deep_copy_cache() is None - assert attr_extensions._DEEP_COPIERS == {} - - -def test_get_fields_definition(): - @attr.define() - class StubModel: - foo: int = attr.field() - bar: bool = attr.field(init=False) - bam: bool = attr.field(init=False) - _voodoo: str = attr.field() - Bat: bool = attr.field() - - fields = {field.name: field for field in attr.fields(StubModel)} - new_model = attr_extensions.get_fields_definition(StubModel) - assert new_model == ( - [(fields["foo"], "foo"), (fields["_voodoo"], "voodoo"), (fields["Bat"], "Bat")], - [fields["bar"], fields["bam"]], - ) - - -def test_generate_shallow_copier(): - @attr.define() - class StubModel: - _foo: int = attr.field() - baaaa: str = attr.field() - _blam: bool = attr.field() - not_init: int = attr.field(init=False) - no: bytes = attr.field() - - old_model = StubModel(foo=42, baaaa="sheep", blam=True, no=b"okokokok") - old_model.not_init = 54234 - - copier = attr_extensions.generate_shallow_copier(StubModel) - new_model = copier(old_model) - - assert new_model is not old_model - assert new_model._foo is old_model._foo - assert new_model.baaaa is old_model.baaaa - assert new_model._blam is old_model._blam - assert new_model.not_init is old_model.not_init - assert new_model.no is old_model.no - - -def test_generate_shallow_copier_with_init_only_arguments(): - @attr.define() - class StubModel: - _gfd: int = attr.field() - baaaa: str = attr.field() - _blambat: bool = attr.field() - no: bytes = attr.field() - - old_model = StubModel(gfd=42, baaaa="sheep", blambat=True, no=b"okokokok") - - copier = attr_extensions.generate_shallow_copier(StubModel) - new_model = copier(old_model) - - assert new_model is not old_model - assert new_model._gfd is old_model._gfd - assert new_model.baaaa is old_model.baaaa - assert new_model._blambat is old_model._blambat - assert new_model.no is old_model.no - - -def test_generate_shallow_copier_with_only_non_init_attrs(): - @attr.define() - class StubModel: - _gfd: int = attr.field(init=False) - baaaa: str = attr.field(init=False) - _blambat: bool = attr.field(init=False) - no: bytes = attr.field(init=False) - - old_model = StubModel() - old_model._gfd = 42 - old_model.baaaa = "sheep" - old_model._blambat = True - old_model.no = b"okokokok" - - copier = attr_extensions.generate_shallow_copier(StubModel) - new_model = copier(old_model) - - assert new_model is not old_model - assert new_model._gfd is old_model._gfd - assert new_model.baaaa is old_model.baaaa - assert new_model._blambat is old_model._blambat - assert new_model.no is old_model.no - - -def test_generate_shallow_copier_with_no_attributes(): - @attr.define() - class StubModel: - ... - - old_model = StubModel() - - copier = attr_extensions.generate_shallow_copier(StubModel) - new_model = copier(old_model) - - assert new_model is not old_model - assert isinstance(new_model, StubModel) - - -def test_get_or_generate_shallow_copier_for_cached_copier(): - mock_copier = object() - - @attr.define() - class StubModel: - ... - - attr_extensions._SHALLOW_COPIERS = { - type("b", (), {}): object(), - StubModel: mock_copier, - type("a", (), {}): object(), - } - - assert attr_extensions.get_or_generate_shallow_copier(StubModel) is mock_copier - - -def test_get_or_generate_shallow_copier_for_uncached_copier(): - mock_copier = object() - - @attr.define() - class StubModel: - ... - - with mock.patch.object(attr_extensions, "generate_shallow_copier", return_value=mock_copier): - assert attr_extensions.get_or_generate_shallow_copier(StubModel) is mock_copier - - attr_extensions.generate_shallow_copier.assert_called_once_with(StubModel) - - assert attr_extensions._SHALLOW_COPIERS[StubModel] is mock_copier - - -def test_copy_attrs(): - mock_result = object() - mock_copier = mock.Mock(return_value=mock_result) - - @attr.define() - class StubModel: - ... - - model = StubModel() - - with mock.patch.object(attr_extensions, "get_or_generate_shallow_copier", return_value=mock_copier): - assert attr_extensions.copy_attrs(model) is mock_result - - attr_extensions.get_or_generate_shallow_copier.assert_called_once_with(StubModel) - - mock_copier.assert_called_once_with(model) - - -def test_generate_deep_copier(): - @attr.define - class StubBaseClass: - recursor: int = attr.field() - _field: bool = attr.field() - foo: str = attr.field() - end: str = attr.field(init=False) - _blam: bool = attr.field(init=False) - - model = StubBaseClass(recursor=431, field=True, foo="blam") - model.end = "the way" - model._blam = "555555" - old_model_fields = stdlib_copy.copy(model) - copied_recursor = object() - copied_field = object() - copied_foo = object() - copied_end = object() - copied_blam = object() - memo = {123: object()} - - with mock.patch.object( - stdlib_copy, - "deepcopy", - side_effect=[copied_recursor, copied_field, copied_foo, copied_end, copied_blam], - ): - attr_extensions.generate_deep_copier(StubBaseClass)(model, memo) - - stdlib_copy.deepcopy.assert_has_calls( - [ - mock.call(old_model_fields.recursor, memo), - mock.call(old_model_fields._field, memo), - mock.call(old_model_fields.foo, memo), - mock.call(old_model_fields.end, memo), - mock.call(old_model_fields._blam, memo), - ] - ) - - assert model.recursor is copied_recursor - assert model._field is copied_field - assert model.foo is copied_foo - assert model.end is copied_end - assert model._blam is copied_blam - - -def test_generate_deep_copier_with_only_init_attributes(): - @attr.define - class StubBaseClass: - recursor: int = attr.field() - _field: bool = attr.field() - foo: str = attr.field() - - model = StubBaseClass(recursor=431, field=True, foo="blam") - old_model_fields = stdlib_copy.copy(model) - copied_recursor = object() - copied_field = object() - copied_foo = object() - memo = {123: object()} - - with mock.patch.object( - stdlib_copy, - "deepcopy", - side_effect=[copied_recursor, copied_field, copied_foo], - ): - attr_extensions.generate_deep_copier(StubBaseClass)(model, memo) - - stdlib_copy.deepcopy.assert_has_calls( - [ - mock.call(old_model_fields.recursor, memo), - mock.call(old_model_fields._field, memo), - mock.call(old_model_fields.foo, memo), - ] - ) - - assert model.recursor is copied_recursor - assert model._field is copied_field - assert model.foo is copied_foo - - -def test_generate_deep_copier_with_only_non_init_attributes(): - @attr.define - class StubBaseClass: - end: str = attr.field(init=False) - _blam: bool = attr.field(init=False) - - model = StubBaseClass() - model.end = "the way" - model._blam = "555555" - old_model_fields = stdlib_copy.copy(model) - copied_end = object() - copied_blam = object() - memo = {123: object()} - - with mock.patch.object( - stdlib_copy, - "deepcopy", - side_effect=[copied_end, copied_blam], - ): - attr_extensions.generate_deep_copier(StubBaseClass)(model, memo) - - stdlib_copy.deepcopy.assert_has_calls( - [ - mock.call(old_model_fields.end, memo), - mock.call(old_model_fields._blam, memo), - ] - ) - - assert model.end is copied_end - assert model._blam is copied_blam - - -def test_generate_deep_copier_with_no_attributes(): - @attr.define - class StubBaseClass: - ... - - model = StubBaseClass() - memo = {123: object()} - - with mock.patch.object( - stdlib_copy, - "deepcopy", - side_effect=NotImplementedError, - ): - attr_extensions.generate_deep_copier(StubBaseClass)(model, memo) - - stdlib_copy.deepcopy.assert_not_called() - - -def test_get_or_generate_deep_copier_for_cached_function(): - class StubClass: - ... - - mock_copier = object() - attr_extensions._DEEP_COPIERS = {} - - with mock.patch.object(attr_extensions, "generate_deep_copier", return_value=mock_copier): - assert attr_extensions.get_or_generate_deep_copier(StubClass) is mock_copier - - attr_extensions.generate_deep_copier.assert_called_once_with(StubClass) - - assert attr_extensions._DEEP_COPIERS[StubClass] is mock_copier - - -def test_get_or_generate_deep_copier_for_uncached_function(): - class StubClass: - ... - - mock_copier = object() - attr_extensions._DEEP_COPIERS = {StubClass: mock_copier} - - with mock.patch.object(attr_extensions, "generate_deep_copier"): - assert attr_extensions.get_or_generate_deep_copier(StubClass) is mock_copier - - attr_extensions.generate_deep_copier.assert_not_called() - - -def test_deep_copy_attrs_without_memo(): - class StubClass: - ... - - mock_object = StubClass() - mock_result = object() - mock_copier = mock.Mock(mock_result) - - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(attr_extensions, "get_or_generate_deep_copier", return_value=mock_copier)) - stack.enter_context(mock.patch.object(stdlib_copy, "copy", return_value=mock_result)) - - with stack: - assert attr_extensions.deep_copy_attrs(mock_object) is mock_result - - stdlib_copy.copy.assert_called_once_with(mock_object) - attr_extensions.get_or_generate_deep_copier.assert_called_once_with(StubClass) - - mock_copier.assert_called_once_with(mock_result, {id(mock_object): mock_result}) - - -def test_deep_copy_attrs_with_memo(): - class StubClass: - ... - - mock_object = StubClass() - mock_result = object() - mock_copier = mock.Mock(mock_result) - mock_other_object = object() - - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(attr_extensions, "get_or_generate_deep_copier", return_value=mock_copier)) - stack.enter_context(mock.patch.object(stdlib_copy, "copy", return_value=mock_result)) - - with stack: - assert attr_extensions.deep_copy_attrs(mock_object, {1235342: mock_other_object}) is mock_result - - stdlib_copy.copy.assert_called_once_with(mock_object) - attr_extensions.get_or_generate_deep_copier.assert_called_once_with(StubClass) - - mock_copier.assert_called_once_with(mock_result, {id(mock_object): mock_result, 1235342: mock_other_object}) - - -class TestCopyDecorator: - def test___copy__(self): - mock_result = object() - mock_copier = mock.Mock(return_value=mock_result) - - @attr.define() - @attr_extensions.with_copy - class StubClass: - ... - - model = StubClass() - - with mock.patch.object(attr_extensions, "get_or_generate_shallow_copier", return_value=mock_copier): - assert stdlib_copy.copy(model) is mock_result - - attr_extensions.get_or_generate_shallow_copier.assert_called_once_with(StubClass) - - mock_copier.assert_called_once_with(model) - - def test___deep__copy(self): - class CopyingMock(mock.Mock): - def __call__(self, /, *args, **kwargs): - args = list(args) - args[1] = dict(args[1]) - return super().__call__(*args, **kwargs) - - mock_result = object() - mock_copier = CopyingMock(return_value=mock_result) - - @attr.define() - @attr_extensions.with_copy - class StubClass: - ... - - model = StubClass() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(attr_extensions, "get_or_generate_deep_copier", return_value=mock_copier)) - stack.enter_context(mock.patch.object(stdlib_copy, "copy", return_value=mock_result)) - - with stack: - assert stdlib_copy.deepcopy(model) is mock_result - - stdlib_copy.copy.assert_called_once_with(model) - attr_extensions.get_or_generate_deep_copier.assert_called_once_with(StubClass) - - mock_copier.assert_called_once_with(mock_result, {id(model): mock_result}) - - def test_copy_decorator_inheritance(self): - @attr_extensions.with_copy - @attr.define() - class ParentClass: - ... - - class Foo(ParentClass): - ... - - assert Foo.__copy__ == attr_extensions.copy_attrs - assert Foo.__deepcopy__ == attr_extensions.deep_copy_attrs diff --git a/tests/hikari/internal/test_routes.py b/tests/hikari/internal/test_routes.py index 9783105a50..71e24179b6 100644 --- a/tests/hikari/internal/test_routes.py +++ b/tests/hikari/internal/test_routes.py @@ -300,10 +300,9 @@ def test_compile_generates_expected_url(self, base_url, template, format, size_k @pytest.mark.parametrize("size", [64, 256, 2048]) def test_compile_to_file_calls_compile(self, format, size): with mock.patch.object(files, "URL", autospec=files.URL): - route = hikari_test_helpers.mock_class_namespace(routes.CDNRoute, slots_=False)( + route = hikari_test_helpers.mock_class_namespace(routes.CDNRoute, slots_=False, compile=mock.Mock())( "/hello/world", {"png", "jpg"}, sizable=True ) - route.compile = mock.Mock(spec_set=route.compile) route.compile_to_file("https://blep.com", file_format=format, size=size, boop="oyy lumo", nya="weeb") route.compile.assert_called_once_with( "https://blep.com", file_format=format, size=size, boop="oyy lumo", nya="weeb" @@ -313,10 +312,9 @@ def test_compile_to_file_passes_compile_result_to_URL_and_returns_constructed_ur resultant_url_str = "http://blep.com/hello/world/weeb/oyy%20lumo" resultant_url = files.URL("http://blep.com/hello/world/weeb/oyy%20lumo") with mock.patch.object(files, "URL", autospec=files.URL, return_value=resultant_url) as URL: - route = hikari_test_helpers.mock_class_namespace(routes.CDNRoute, slots_=False)( - "/hello/world/{nya}/{boop}", {"png", "jpg"}, sizable=True - ) - route.compile = mock.Mock(spec_set=route.compile, return_value=resultant_url_str) + route = hikari_test_helpers.mock_class_namespace( + routes.CDNRoute, slots_=False, compile=mock.Mock(return_value=resultant_url_str) + )("/hello/world/{nya}/{boop}", {"png", "jpg"}, sizable=True) result = route.compile_to_file("https://blep.com", file_format="png", size=64, boop="oyy lumo", nya="weeb") URL.assert_called_once_with(resultant_url_str) diff --git a/tests/hikari/test_applications.py b/tests/hikari/test_applications.py index 327409f302..17a15b30c8 100644 --- a/tests/hikari/test_applications.py +++ b/tests/hikari/test_applications.py @@ -57,10 +57,6 @@ def test_flags_property(self, model): def test_id_property(self, model): assert model.id is model.user.id - def test_id_setter(self, model): - with pytest.raises(TypeError, match="Cannot mutate the ID of a member"): - model.id = 42 - def test_is_bot_property(self, model): assert model.is_bot is model.user.is_bot @@ -99,15 +95,17 @@ def test_str_operator(self): team = applications.Team(id=696969, app=object(), name="test", icon_hash="", members=[], owner_id=0) assert str(team) == "Team test (696969)" - def test_icon_url_property(self, model): - model.make_icon_url = mock.Mock(return_value="url") + def test_icon_url_property(self): + model = hikari_test_helpers.mock_class_namespace( + applications.Team, init_=False, make_icon_url=mock.Mock(return_value="url") + )() assert model.icon_url == "url" model.make_icon_url.assert_called_once_with() - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None + def test_make_icon_url_when_hash_is_None(self): + model = applications.Team(app=None, id=None, name=None, icon_hash=None, members=None, owner_id=None) with mock.patch.object( routes, "CDN_TEAM_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) @@ -139,15 +137,21 @@ def model(self): cover_image_hash="ahashcover", )() - def test_cover_image_url_property(self, model): - model.make_cover_image_url = mock.Mock(return_value="url") + def test_cover_image_url_property(self): + model = hikari_test_helpers.mock_class_namespace( + applications.Application, init_=False, make_cover_image_url=mock.Mock(return_value="url") + )() assert model.cover_image_url == "url" model.make_cover_image_url.assert_called_once_with() - def test_make_cover_image_url_when_hash_is_None(self, model): - model.cover_image_hash = None + def test_make_cover_image_url_when_hash_is_None(self): + model = hikari_test_helpers.mock_class_namespace( + applications.Application, + init_=False, + cover_image_hash=None, + )() with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) diff --git a/tests/hikari/test_channels.py b/tests/hikari/test_channels.py index 2682a46f90..5c20172587 100644 --- a/tests/hikari/test_channels.py +++ b/tests/hikari/test_channels.py @@ -87,10 +87,11 @@ def test_channel_when_no_cache_trait(self): class TestPermissionOverwrite: def test_unset(self): overwrite = channels.PermissionOverwrite( - type=channels.PermissionOverwriteType.MEMBER, id=snowflakes.Snowflake(1234321) + type=channels.PermissionOverwriteType.MEMBER, + id=snowflakes.Snowflake(1234321), + allow=permissions.Permissions.CREATE_INSTANT_INVITE, + deny=permissions.Permissions.CHANGE_NICKNAME, ) - overwrite.allow = permissions.Permissions.CREATE_INSTANT_INVITE - overwrite.deny = permissions.Permissions.CHANGE_NICKNAME assert overwrite.unset == permissions.Permissions(-67108866) @@ -107,8 +108,10 @@ def model(self, mock_app): def test_str_operator(self, model): assert str(model) == "foo" - def test_str_operator_when_name_is_None(self, model): - model.name = None + def test_str_operator_when_name_is_None(self): + model = hikari_test_helpers.mock_class_namespace( + channels.PartialChannel, init_=False, rename_impl_=False, name=None, id=1234567 + )() assert str(model) == "Unnamed PartialChannel ID 1234567" @@ -137,7 +140,7 @@ def model(self, mock_app): return channels.GroupDMChannel( app=mock_app, id=snowflakes.Snowflake(136134), - name="super cool group dm", + name=None, type=channels.ChannelType.DM, last_message_id=snowflakes.Snowflake(3232), owner_id=snowflakes.Snowflake(1066), @@ -154,11 +157,13 @@ def model(self, mock_app): application_id=None, ) - def test_str_operator(self, model): + def test_str_operator(self): + model = hikari_test_helpers.mock_class_namespace( + channels.GroupDMChannel, init_=False, name="super cool group dm" + )() assert str(model) == "super cool group dm" def test_str_operator_when_name_is_None(self, model): - model.name = None assert str(model) == "GroupDMChannel with: snoop#0420, yeet#1012, nice#6969" def test_icon_url(self): @@ -178,8 +183,8 @@ def test_make_icon_url_without_optional_params(self, model): "https://cdn.discordapp.com/channel-icons/136134/1a2b3c.png?size=4096" ) - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None + def test_make_icon_url_when_hash_is_None(self): + model = hikari_test_helpers.mock_class_namespace(channels.GroupDMChannel, init_=False, icon_hash=None)() assert model.make_icon_url() is None @@ -269,18 +274,11 @@ def model(self, mock_app): parent_id=None, ) - @pytest.mark.parametrize("error", [TypeError, AttributeError, NameError]) - def test_shard_id_property_when_guild_id_error_raised(self, model, error): - class BrokenApp: - def __getattr__(self, name): - if name == "shard_count": - raise error - return mock.Mock() - - model.app = BrokenApp() + def test_shard_id_property_when_not_shard_aware(self): + model = hikari_test_helpers.mock_class_namespace(channels.GuildChannel, init_=False, app=None)() assert model.shard_id is None - def test_shard_id_property_when_guild_id_is_not_None(self, model): + def test_shard_id_property_when_shard_aware(self, model): model.app.shard_count = 3 assert model.shard_id == 2 diff --git a/tests/hikari/test_embeds.py b/tests/hikari/test_embeds.py index d4d23eb441..656477de1f 100644 --- a/tests/hikari/test_embeds.py +++ b/tests/hikari/test_embeds.py @@ -52,13 +52,13 @@ def resource_with_proxy(self): def test_proxy_url(self, resource_with_proxy): assert resource_with_proxy.proxy_url is resource_with_proxy.proxy_resource.url - def test_proxy_url_when_resource_is_none(self, resource_with_proxy): - resource_with_proxy.proxy_resource = None + def test_proxy_url_when_resource_is_none(self): + resource_with_proxy = embeds.EmbedResourceWithProxy(resource=mock.Mock(), proxy_resource=None) assert resource_with_proxy.proxy_url is None def test_proxy_filename(self, resource_with_proxy): assert resource_with_proxy.proxy_filename is resource_with_proxy.proxy_resource.filename - def test_proxy_filename_when_resource_is_none(self, resource_with_proxy): - resource_with_proxy.proxy_resource = None + def test_proxy_filename_when_resource_is_none(self): + resource_with_proxy = embeds.EmbedResourceWithProxy(resource=mock.Mock(), proxy_resource=None) assert resource_with_proxy.proxy_filename is None diff --git a/tests/hikari/test_errors.py b/tests/hikari/test_errors.py index 299e523798..7e53caaf01 100644 --- a/tests/hikari/test_errors.py +++ b/tests/hikari/test_errors.py @@ -79,12 +79,11 @@ def test_str(self, error): assert str(error) == "Bad Request 400: 'raw body' for https://some.url" def test_str_when_status_is_not_HTTPStatus(self, error): - error.status = "SOME STATUS" + error = errors.HTTPResponseError("https://some.url", "SOME STATUS", {}, "raw body") assert str(error) == "Some Status: 'raw body' for https://some.url" - def test_str_when_message_is_not_None(self, error): - error.status = "SOME STATUS" - error.message = "Some message" + def test_str_when_message_is_not_None(self): + error = errors.HTTPResponseError("https://some.url", "SOME STATUS", {}, "raw body", "Some message") assert str(error) == "Some Status: 'Some message' for https://some.url" diff --git a/tests/hikari/test_guilds.py b/tests/hikari/test_guilds.py index e179cad1e8..416c6d0394 100644 --- a/tests/hikari/test_guilds.py +++ b/tests/hikari/test_guilds.py @@ -70,15 +70,17 @@ def model(self): icon_hash="ahashicon", )() - def test_icon_url_property(self, model): - model.make_icon_url = mock.Mock(return_value="url") + def test_icon_url_property(self): + model = hikari_test_helpers.mock_class_namespace( + guilds.PartialApplication, init_=False, make_icon_url=mock.Mock(return_value="url") + )() assert model.icon_url == "url" model.make_icon_url.assert_called_once_with() def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None + model = hikari_test_helpers.mock_class_namespace(guilds.PartialApplication, init_=False, icon_hash=None)() with mock.patch.object( routes, "CDN_APPLICATION_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) @@ -175,10 +177,6 @@ def test_app_property(self, model, mock_user): def test_id_property(self, model, mock_user): assert model.id is mock_user.id - def test_id_setter_property(self, model): - with pytest.raises(TypeError): - model.id = 456 - def test_username_property(self, model, mock_user): assert model.username is mock_user.username @@ -234,23 +232,22 @@ def test_default_avatar_url_property(self, model, mock_user): def test_display_name_property_when_nickname(self, model): assert model.display_name == "davb" - def test_display_name_property_when_no_nickname(self, model, mock_user): - model.nickname = None + def test_display_name_property_when_no_nickname(self, mock_user): + model = hikari_test_helpers.mock_class_namespace(guilds.Member, nickname=None, user=mock_user, init_=False)() assert model.display_name is mock_user.username def test_mention_property_when_nickname(self, model): assert model.mention == "<@!123>" - def test_mention_property_when_no_nickname(self, model, mock_user): - model.nickname = None + def test_mention_property_when_no_nickname(self, mock_user): + model = hikari_test_helpers.mock_class_namespace(guilds.Member, nickname=None, user=mock_user, init_=False)() assert model.mention == mock_user.mention def test_roles(self, model): - role1 = mock.Mock(id=321, position=2) - role2 = mock.Mock(id=654, position=1) - mock_cache_view = {321: role1, 654: role2} + role1 = mock.Mock(id=456, position=2) + role2 = mock.Mock(id=1234, position=1) + mock_cache_view = {456: role1, 1234: role2} model.user.app.cache.get_roles_view_for_guild.return_value = mock_cache_view - model.role_ids = [321, 654] assert model.roles == [role1, role2] @@ -258,10 +255,9 @@ def test_roles(self, model): def test_roles_when_role_ids_not_in_cache(self, model): role1 = mock.Mock(id=123, position=2) - role2 = mock.Mock(id=456, position=1) - mock_cache_view = {123: role1, 456: role2} + role2 = mock.Mock(id=1234, position=1) + mock_cache_view = {123: role1, 1234: role2} model.user.app.cache.get_roles_view_for_guild.return_value = mock_cache_view - model.role_ids = [321, 456] assert model.roles == [role2] @@ -317,7 +313,7 @@ def test_shard_id_property(self, model): assert model.shard_id == 0 def test_shard_id_when_not_shard_aware(self, model): - model.app = object() + model = hikari_test_helpers.mock_class_namespace(guilds.PartialGuild, init_=False, app=None)() assert model.shard_id is None @@ -327,14 +323,15 @@ def test_icon_url(self, model): with mock.patch.object(guilds.PartialGuild, "make_icon_url", return_value=icon): assert model.icon_url is icon - def test_make_icon_url_when_no_hash(self, model): - model.icon_hash = None + def test_make_icon_url_when_no_hash(self): + model = hikari_test_helpers.mock_class_namespace(guilds.PartialGuild, init_=False, icon_hash=None)() assert model.make_icon_url(ext="png", size=2048) is None - def test_make_icon_url_when_format_is_None_and_avatar_hash_is_for_gif(self, model): - model.icon_hash = "a_yeet" - + def test_make_icon_url_when_format_is_None_and_avatar_hash_is_for_gif(self): + model = hikari_test_helpers.mock_class_namespace( + guilds.PartialGuild, init_=False, icon_hash="a_yeet", id=90210 + )() with mock.patch.object( routes, "CDN_GUILD_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -387,7 +384,7 @@ def model(self, mock_app): icon_hash="dis is mah icon hash", name="DAPI", splash_hash="dis is also mah splash hash", - discovery_splash_hash=None, + discovery_splash_hash="okokokok", emojis={}, approximate_active_member_count=12, approximate_member_count=999_283_252_124_633, @@ -401,8 +398,6 @@ def test_splash_url(self, model): assert model.splash_url is splash def test_make_splash_url_when_hash(self, model): - model.splash_hash = "18dnf8dfbakfdh" - with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -411,13 +406,13 @@ def test_make_splash_url_when_hash(self, model): route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, - hash="18dnf8dfbakfdh", + hash="dis is also mah splash hash", size=1024, file_format="url", ) def test_make_splash_url_when_no_hash(self, model): - model.splash_hash = None + model = hikari_test_helpers.mock_class_namespace(guilds.GuildPreview, splash_hash=None, init_=False)() assert model.make_splash_url(ext="png", size=512) is None def test_discovery_splash_url(self, model): @@ -427,8 +422,6 @@ def test_discovery_splash_url(self, model): assert model.discovery_splash_url is discovery_splash def test_make_discovery_splash_url_when_hash(self, model): - model.discovery_splash_hash = "18dnf8dfbakfdh" - with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -437,13 +430,13 @@ def test_make_discovery_splash_url_when_hash(self, model): route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, - hash="18dnf8dfbakfdh", + hash="okokokok", size=2048, file_format="url", ) - def test_make_discovery_splash_url_when_no_hash(self, model): - model.discovery_splash_hash = None + def test_make_discovery_splash_url_when_no_hash(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GuildPreview, init_=False, discovery_splash_hash=None)() assert model.make_discovery_splash_url(ext="png", size=4096) is None @@ -489,8 +482,6 @@ def test_splash_url(self, model): assert model.splash_url is splash def test_make_splash_url_when_hash(self, model): - model.splash_hash = "18dnf8dfbakfdh" - with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -499,13 +490,13 @@ def test_make_splash_url_when_hash(self, model): route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, - hash="18dnf8dfbakfdh", + hash="splash_hash", size=2, file_format="url", ) - def test_make_splash_url_when_no_hash(self, model): - model.splash_hash = None + def test_make_splash_url_when_no_hash(self): + model = hikari_test_helpers.mock_class_namespace(guilds.Guild, splash_hash=None, init_=False)() assert model.make_splash_url(ext="png", size=1024) is None def test_discovery_splash_url(self, model): @@ -515,8 +506,6 @@ def test_discovery_splash_url(self, model): assert model.discovery_splash_url is discovery_splash def test_make_discovery_splash_url_when_hash(self, model): - model.discovery_splash_hash = "18dnf8dfbakfdh" - with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -525,13 +514,13 @@ def test_make_discovery_splash_url_when_hash(self, model): route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, - hash="18dnf8dfbakfdh", + hash="discovery_splash_hash", size=1024, file_format="url", ) - def test_make_discovery_splash_url_when_no_hash(self, model): - model.discovery_splash_hash = None + def test_make_discovery_splash_url_when_no_hash(self): + model = hikari_test_helpers.mock_class_namespace(guilds.Guild, init_=False, discovery_splash_hash=None)() assert model.make_discovery_splash_url(ext="png", size=2048) is None def test_banner_url(self, model): @@ -555,7 +544,7 @@ def test_make_banner_url_when_hash(self, model): ) def test_make_banner_url_when_no_hash(self, model): - model.banner_hash = None + model = hikari_test_helpers.mock_class_namespace(guilds.Guild, init_=False, banner_hash=None)() assert model.make_banner_url(ext="png", size=2048) is None @@ -600,15 +589,19 @@ def model(self, mock_app): max_members=100, ) - def test_get_emoji(self, model): + def test_get_emoji(self): emoji = object() - model._emojis = {snowflakes.Snowflake(123): emoji} + model = hikari_test_helpers.mock_class_namespace( + guilds.RESTGuild, _emojis={snowflakes.Snowflake(123): emoji}, init_=False + )() assert model.get_emoji(123) is emoji - def test_get_role(self, model): + def test_get_role(self): role = object() - model._roles = {snowflakes.Snowflake(123): role} + model = hikari_test_helpers.mock_class_namespace( + guilds.RESTGuild, _roles={snowflakes.Snowflake(123): role}, init_=False + )() assert model.get_role(123) is role @@ -655,104 +648,104 @@ def test_channels(self, model): assert model.channels is model.app.cache.get_guild_channels_view_for_guild.return_value model.app.cache.get_guild_channels_view_for_guild.assert_called_once_with(123) - def test_channels_when_no_cache_trait(self, model): - model.app = object() + def test_channels_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.channels == {} def test_emojis(self, model): assert model.emojis is model.app.cache.get_emojis_view_for_guild.return_value model.app.cache.get_emojis_view_for_guild.assert_called_once_with(123) - def test_emojis_when_no_cache_trait(self, model): - model.app = object() + def test_emojis_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.emojis == {} def test_members(self, model): assert model.members is model.app.cache.get_members_view_for_guild.return_value model.app.cache.get_members_view_for_guild.assert_called_once_with(123) - def test_members_when_no_cache_trait(self, model): - model.app = object() + def test_members_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.members == {} def test_presences(self, model): assert model.presences is model.app.cache.get_presences_view_for_guild.return_value model.app.cache.get_presences_view_for_guild.assert_called_once_with(123) - def test_presences_when_no_cache_trait(self, model): - model.app = object() + def test_presences_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.presences == {} def test_roles(self, model): assert model.roles is model.app.cache.get_roles_view_for_guild.return_value model.app.cache.get_roles_view_for_guild.assert_called_once_with(123) - def test_roles_when_no_cache_trait(self, model): - model.app = object() + def test_roles_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.roles == {} def test_voice_states(self, model): assert model.voice_states is model.app.cache.get_voice_states_view_for_guild.return_value model.app.cache.get_voice_states_view_for_guild.assert_called_once_with(123) - def test_voice_states_when_no_cache_trait(self, model): - model.app = object() + def test_voice_states_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.voice_states == {} def test_get_channel(self, model): assert model.get_channel(456) is model.app.cache.get_guild_channel.return_value model.app.cache.get_guild_channel.assert_called_once_with(456) - def test_get_channel_when_no_cache_trait(self, model): - model.app = object() + def test_get_channel_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_channel(456) is None def test_get_emoji(self, model): assert model.get_emoji(456) is model.app.cache.get_emoji.return_value model.app.cache.get_emoji.assert_called_once_with(456) - def test_get_emoji_when_no_cache_trait(self, model): - model.app = object() + def test_get_emoji_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_emoji(456) is None def test_get_member(self, model): assert model.get_member(456) is model.app.cache.get_member.return_value model.app.cache.get_member.assert_called_once_with(123, 456) - def test_get_member_when_no_cache_trait(self, model): - model.app = object() + def test_get_member_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_member(456) is None def test_get_presence(self, model): assert model.get_presence(456) is model.app.cache.get_presence.return_value model.app.cache.get_presence.assert_called_once_with(123, 456) - def test_get_presence_when_no_cache_trait(self, model): - model.app = object() + def test_get_presence_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_presence(456) is None def test_get_role(self, model): assert model.get_role(456) is model.app.cache.get_role.return_value model.app.cache.get_role.assert_called_once_with(456) - def test_get_role_when_no_cache_trait(self, model): - model.app = object() + def test_get_role_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_role(456) is None def test_get_voice_state(self, model): assert model.get_voice_state(456) is model.app.cache.get_voice_state.return_value model.app.cache.get_voice_state.assert_called_once_with(123, 456) - def test_get_voice_state_when_no_cache_trait(self, model): - model.app = object() + def test_get_voice_state_when_no_cache_trait(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_voice_state(456) is None - def test_get_my_member_when_not_shardaware(self, model): - model.app = object() + def test_get_my_member_when_not_shardaware(self): + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=None)() assert model.get_my_member() is None def test_get_my_member_when_no_me(self, model): - model.app.me = None + model = hikari_test_helpers.mock_class_namespace(guilds.GatewayGuild, init_=False, app=mock.Mock(me=None))() assert model.get_my_member() is None def test_get_my_member(self, model): diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 7e06568369..c19624b092 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -31,6 +31,7 @@ from hikari import urls from hikari import users from hikari.internal import routes +from tests.hikari import hikari_test_helpers class TestAttachment: @@ -71,8 +72,10 @@ def test_cover_image_url(self, message_application): with mock.patch.object(messages.MessageApplication, "make_cover_image_url") as mock_cover_image: assert message_application.cover_image_url is mock_cover_image() - def test_make_cover_image_url_when_hash_is_none(self, message_application): - message_application.cover_image_hash = None + def test_make_cover_image_url_when_hash_is_none(self): + message_application = hikari_test_helpers.mock_class_namespace( + messages.MessageApplication, init_=False, cover_image_hash=None + )() assert message_application.make_cover_image_url() is None @@ -90,7 +93,7 @@ def test_make_cover_image_url_when_hash_is_not_none(self, message_application): @pytest.fixture() def message(): return messages.Message( - app=None, + app=mock.Mock(rest=mock.AsyncMock()), id=snowflakes.Snowflake(1234), channel_id=snowflakes.Snowflake(5678), guild_id=snowflakes.Snowflake(910112), @@ -125,25 +128,18 @@ def message(): class TestMessage: def test_make_link_when_guild_is_not_none(self, message): - message.id = 789 - message.channel_id = 456 - assert message.make_link(123) == "https://discord.com/channels/123/456/789" + assert message.make_link(123) == "https://discord.com/channels/123/5678/1234" def test_make_link_when_guild_is_none(self, message): - message.app = mock.Mock() - message.id = 789 - message.channel_id = 456 - assert message.make_link(None) == "https://discord.com/channels/@me/456/789" + assert message.make_link(None) == "https://discord.com/channels/@me/5678/1234" def test_guild_id_when_guild_is_not_none(self, message): - message._guild_id = 123 + assert message.guild_id == 910112 - assert message.guild_id == 123 - - def test_guild_id_when_guild_is_none(self, message): - message.app = mock.Mock() - message._guild_id = None - message.channel_id = 890 + def test_guild_id_when_guild_is_none(self): + message = hikari_test_helpers.mock_class_namespace( + messages.Message, init_=False, _guild_id=None, channel_id=890, app=mock.Mock() + )() message.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(guild_id=456)) assert message.guild_id == 456 @@ -154,15 +150,10 @@ def test_guild_id_when_guild_is_none(self, message): @pytest.mark.asyncio() class TestAsyncMessage: async def test_fetch_channel(self, message): - message.app = mock.AsyncMock() - message.channel_id = 123 await message.fetch_channel() - message.app.rest.fetch_channel.assert_awaited_once_with(123) + message.app.rest.fetch_channel.assert_awaited_once_with(5678) async def test_edit(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 embed = object() attachment = object() roles = [object()] @@ -179,8 +170,8 @@ async def test_edit(self, message): flags=messages.MessageFlag.URGENT, ) message.app.rest.edit_message.assert_awaited_once_with( - message=123, - channel=456, + message=1234, + channel=5678, content="test content", embed=embed, attachment=attachment, @@ -194,9 +185,6 @@ async def test_edit(self, message): ) async def test_respond(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 embed = object() roles = [object()] attachment = object() @@ -216,7 +204,7 @@ async def test_respond(self, message): mentions_reply=True, ) message.app.rest.create_message.assert_awaited_once_with( - channel=456, + channel=5678, content="test content", embed=embed, attachment=attachment, @@ -231,9 +219,6 @@ async def test_respond(self, message): ) async def test_respond_when_reply_is_True(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 embed = object() roles = [object()] attachment = object() @@ -252,7 +237,7 @@ async def test_respond_when_reply_is_True(self, message): mentions_reply=True, ) message.app.rest.create_message.assert_awaited_once_with( - channel=456, + channel=5678, content="test content", embed=embed, attachment=attachment, @@ -267,44 +252,26 @@ async def test_respond_when_reply_is_True(self, message): ) async def test_delete(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 await message.delete() - message.app.rest.delete_message.assert_awaited_once_with(456, 123) + message.app.rest.delete_message.assert_awaited_once_with(5678, 1234) async def test_add_reaction(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 await message.add_reaction("👌") - message.app.rest.add_reaction.assert_awaited_once_with(channel=456, message=123, emoji="👌") + message.app.rest.add_reaction.assert_awaited_once_with(channel=5678, message=1234, emoji="👌") async def test_remove_reaction(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 await message.remove_reaction("👌") - message.app.rest.delete_my_reaction.assert_awaited_once_with(channel=456, message=123, emoji="👌") + message.app.rest.delete_my_reaction.assert_awaited_once_with(channel=5678, message=1234, emoji="👌") async def test_remove_reaction_with_user(self, message): - message.app = mock.AsyncMock() user = object() - message.id = 123 - message.channel_id = 456 await message.remove_reaction("👌", user=user) - message.app.rest.delete_reaction.assert_awaited_once_with(channel=456, message=123, emoji="👌", user=user) + message.app.rest.delete_reaction.assert_awaited_once_with(channel=5678, message=1234, emoji="👌", user=user) async def test_remove_all_reactions(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 await message.remove_all_reactions() - message.app.rest.delete_all_reactions.assert_awaited_once_with(channel=456, message=123) + message.app.rest.delete_all_reactions.assert_awaited_once_with(channel=5678, message=1234) async def test_remove_all_reactions_with_emoji(self, message): - message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 await message.remove_all_reactions("👌") - message.app.rest.delete_all_reactions_for_emoji.assert_awaited_once_with(channel=456, message=123, emoji="👌") + message.app.rest.delete_all_reactions_for_emoji.assert_awaited_once_with(channel=5678, message=1234, emoji="👌") diff --git a/tests/hikari/test_sessions.py b/tests/hikari/test_sessions.py index 031e2fc51d..8c39d05882 100644 --- a/tests/hikari/test_sessions.py +++ b/tests/hikari/test_sessions.py @@ -23,6 +23,7 @@ import datetime from hikari import sessions +from tests.hikari import hikari_test_helpers def test_SessionStartLimit_used_property(): @@ -33,12 +34,10 @@ def test_SessionStartLimit_used_property(): def test_SessionStartLimit_reset_at_property(): - obj = sessions.SessionStartLimit( - total=100, - remaining=2, + obj = hikari_test_helpers.mock_class_namespace( + sessions.SessionStartLimit, + init_=False, + _created_at=datetime.datetime(2020, 7, 22, 22, 22, 36, 988017, tzinfo=datetime.timezone.utc), reset_after=datetime.timedelta(hours=1, days=10), - max_concurrency=1, - ) - obj._created_at = datetime.datetime(2020, 7, 22, 22, 22, 36, 988017, tzinfo=datetime.timezone.utc) - + )() assert obj.reset_at == datetime.datetime(2020, 8, 1, 23, 22, 36, 988017, tzinfo=datetime.timezone.utc) diff --git a/tests/hikari/test_users.py b/tests/hikari/test_users.py index f04e081019..10ae503b63 100644 --- a/tests/hikari/test_users.py +++ b/tests/hikari/test_users.py @@ -195,8 +195,10 @@ def obj(self): def test_str_operator(self, obj): assert str(obj) == "thomm.o#8637" - def test_str_operator_when_partial(self, obj): - obj.username = undefined.UNDEFINED + def test_str_operator_when_partial(self): + obj = hikari_test_helpers.mock_class_namespace( + users.PartialUserImpl, username=undefined.UNDEFINED, id=123, rename_impl_=False, init_=False + )() assert str(obj) == "Partial user ID 123" def test_mention_property(self, obj): diff --git a/tests/hikari/test_webhooks.py b/tests/hikari/test_webhooks.py index 00363693a9..ee3677ea05 100644 --- a/tests/hikari/test_webhooks.py +++ b/tests/hikari/test_webhooks.py @@ -23,6 +23,7 @@ import pytest from hikari import webhooks +from tests.hikari import hikari_test_helpers class TestWebhook: @@ -46,8 +47,9 @@ def webhook(self): def test_str(self, webhook): assert str(webhook) == "not a webhook" - def test_str_when_name_is_None(self, webhook): - webhook.name = None + def test_str_when_name_is_None(self): + webhook = hikari_test_helpers.mock_class_namespace(webhooks.Webhook, init_=False, name=None, id=987654321)() + assert str(webhook) == "Unnamed webhook ID 987654321" @pytest.mark.asyncio @@ -63,8 +65,8 @@ async def test_fetch_message(self, webhook): webhook.app.rest.fetch_webhook_message.assert_called_once_with(987654321, token="abc123bca", message=message) @pytest.mark.asyncio - async def test_fetch_message_when_no_token(self, webhook): - webhook.token = None + async def test_fetch_message_when_no_token(self): + webhook = hikari_test_helpers.mock_class_namespace(webhooks.Webhook, init_=False, token=None)() with pytest.raises(ValueError, match=r"Cannot fetch a message using a webhook where we don't know the token"): await webhook.fetch_message(987) @@ -107,8 +109,8 @@ async def test_edit_message(self, webhook): ) @pytest.mark.asyncio - async def test_edit_message_when_no_token(self, webhook): - webhook.token = None + async def test_edit_message_when_no_token(self): + webhook = hikari_test_helpers.mock_class_namespace(webhooks.Webhook, init_=False, token=None)() with pytest.raises(ValueError, match=r"Cannot edit a message using a webhook where we don't know the token"): await webhook.edit_message(987) @@ -122,7 +124,7 @@ async def test_delete_message(self, webhook): webhook.app.rest.delete_webhook_message.assert_called_once_with(987654321, token="abc123bca", message=message) @pytest.mark.asyncio - async def test_delete_message_when_no_token(self, webhook): - webhook.token = None + async def test_delete_message_when_no_token(self): + webhook = hikari_test_helpers.mock_class_namespace(webhooks.Webhook, init_=False, token=None)() with pytest.raises(ValueError, match=r"Cannot delete a message using a webhook where we don't know the token"): assert await webhook.delete_message(987)