diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index d9bb23e8c..fee91fd31 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -5,18 +5,27 @@ - [ ] Breaking code change - [ ] Documentation change/addition - [ ] Tests change +- [ ] CI change +- [ ] Other: [Replace with a description] ## Description ## Changes - + -## Checklist +## Test Scenario(s) + + - +## Checklist + - [ ] I've formatted my code with [Black](https://black.readthedocs.io/en/stable/) +- [ ] I've added docstrings to everything I've touched - [ ] I've ensured my code works on `Python 3.10.x` -- [ ] I've tested my code +- [ ] I've ensured my code works on `Python 3.11.x` +- [ ] I've tested my changes +- [ ] I've added tests for my code - if applicable +- [ ] I've updated the documentation - if applicable diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index b2cf11966..4448f97cc 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,15 +39,15 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: '3.10' # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -58,7 +58,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v1 + uses: github/codeql-action/autobuild@v2 # ℹī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -72,4 +72,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 05f10e31e..bcbf8c7ea 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,13 +8,13 @@ jobs: runs-on: ubuntu-latest if: github.event.repository.fork == false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: python-version: "3.10" - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: key: ${{ github.ref }} path: .cache diff --git a/.github/workflows/ok-to-test.yml b/.github/workflows/ok-to-test.yml index 14894a437..b802cd769 100644 --- a/.github/workflows/ok-to-test.yml +++ b/.github/workflows/ok-to-test.yml @@ -25,7 +25,7 @@ jobs: private_key: ${{ secrets.PRIVATE_KEY }} - name: Slash Command Dispatch - uses: peter-evans/slash-command-dispatch@v1 + uses: peter-evans/slash-command-dispatch@v3 env: TOKEN: ${{ steps.generate_token.outputs.token }} with: diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index d7c34b9cc..b50d944cd 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -5,9 +5,9 @@ jobs: name: runner / Pre-commit actions runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python 3.10 - uses: actions/setup-python@v2.3.1 + uses: actions/setup-python@v4 with: python-version: '3.10' - name: Setup annotations diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 35acb3862..1095c6f99 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -7,9 +7,9 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: '3.10' - name: Install dependencies diff --git a/.github/workflows/pytest-pr.yml b/.github/workflows/pytest-pr.yml index fd0db0894..990fb5e40 100644 --- a/.github/workflows/pytest-pr.yml +++ b/.github/workflows/pytest-pr.yml @@ -15,10 +15,20 @@ jobs: - .[speedup] - .[voice] - .[all] + python-version: + - "3.10" + - "3.11" + include: + - extras: .[all] + python-version: "3.10" + BOT_TOKEN: ${{ secrets.BOT_TOKEN }} + - extras: .[all] + python-version: "3.11" + BOT_TOKEN: ${{ secrets.BOT_TOKEN }} steps: - name: Create check run - uses: actions/github-script@v5 + uses: actions/github-script@v6 id: create-check-run env: number: ${{ github.event.client_payload.pull_request.number }} @@ -42,13 +52,13 @@ jobs: return result.id; - name: Fork based /ok-to-test checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: 'refs/pull/${{ github.event.client_payload.pull_request.number }}/merge' - - name: Set up Python 3.10 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2.3.1 with: - python-version: '3.10' + python-version: ${{ matrix.python-version }} - name: Install ffmpeg & opus run: sudo apt-get install ffmpeg libopus-dev - name: Install pytest @@ -67,14 +77,14 @@ jobs: chmod +x codecov ./codecov - name: Publish Test Report - uses: mikepenz/action-junit-report@v2 + uses: mikepenz/action-junit-report@v3 if: always() # always run even if the previous step fails with: report_paths: '**/TestResults.xml' check_name: 'Pytest Results' - name: Update check run - uses: actions/github-script@v5 + uses: actions/github-script@v6 id: update-check-run if: ${{ always() }} env: diff --git a/.github/workflows/pytest-push.yml b/.github/workflows/pytest-push.yml index 2378b2135..e1c41bd78 100644 --- a/.github/workflows/pytest-push.yml +++ b/.github/workflows/pytest-push.yml @@ -13,22 +13,35 @@ jobs: - .[voice] - .[all] - .[docs] + python-version: + - "3.10" + - "3.11" + include: + - extras: .[all] + python-version: "3.10" + RUN_TESTBOT: true + - extras: .[all] + python-version: "3.11" + RUN_TESTBOT: true steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.10 + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2.3.1 with: - python-version: '3.10' + python-version: ${{ matrix.python-version }} + cache: 'pip' - name: Install ffmpeg & opus run: sudo apt-get update && sudo apt-get install ffmpeg libopus-dev - name: Install pytest run: | + pip install wheel pip install -e ${{ matrix.extras }} pip install .[tests] - name: Run Tests env: BOT_TOKEN: ${{ secrets.BOT_TOKEN }} + RUN_TESTBOT: ${{ matrix.RUN_TESTBOT }} run: | pytest coverage xml -i @@ -38,7 +51,7 @@ jobs: chmod +x codecov ./codecov - name: Publish Test Report - uses: mikepenz/action-junit-report@v2 + uses: mikepenz/action-junit-report@v3 if: always() # always run even if the previous step fails with: report_paths: '**/TestResults.xml' diff --git a/docs/src/API Reference/ext/index.md b/docs/src/API Reference/ext/index.md index 5ddb94028..1283405cf 100644 --- a/docs/src/API Reference/ext/index.md +++ b/docs/src/API Reference/ext/index.md @@ -5,6 +5,9 @@ These files contain useful features that help you develop a bot - [Debug Extension](debug_ext) - An extension preloaded with a load of debugging utilities to help you find and fix bugs +- [Jurigged](jurigged) + - An extension to enable live code patching for faster development + - [Paginators](paginators) - An automatic message paginator to help you get a lot of information across diff --git a/docs/src/API Reference/ext/jurigged.md b/docs/src/API Reference/ext/jurigged.md new file mode 100644 index 000000000..ace945bf2 --- /dev/null +++ b/docs/src/API Reference/ext/jurigged.md @@ -0,0 +1 @@ +::: naff.ext.jurigged diff --git a/docs/src/Guides/05 Components.md b/docs/src/Guides/05 Components.md index 09dfa89d3..454caee68 100644 --- a/docs/src/Guides/05 Components.md +++ b/docs/src/Guides/05 Components.md @@ -99,20 +99,18 @@ await channel.send("Look a Button!", components=components) Sometimes there might be more than a handful options which users need to decide between. That's when a `Select` should probably be used. -Selects are very similar to Buttons. The main difference is that they need options, which you supply by passing a list of `SelectOption`. +Selects are very similar to Buttons. The main difference is that you get a list of options to choose from. + +If you want to use string options, then you use `StringSelect`. Simply pass a list of strings to `options` and you are good to go. You can also explicitly pass `SelectOptions` to control the value attribute. You can also define how many options users can choose by setting `min_values` and `max_values`. + ```python -components = Select( +from naff import StringSelectMenu, SelectOption + +components = StringSelectMenu( options=[ - SelectOption( - label="Pizza", - value="Pizza" - ), - SelectOption( - label="Egg Sandwich", - value="Egg Sandwich" - ), + "Pizza", "Pasta", "Burger", "Salad" ], placeholder="What is your favourite food?", min_values=1, @@ -124,6 +122,10 @@ await channel.send("Look a Select!", components=components) ??? note You can only have upto 25 options in a Select +Alternatively, you can use `RoleSelectMenu`, `UserSelectMenu` and `ChannelSelectMenu` to select roles, users and channels respectively. These select menus are very similar to `StringSelectMenu`, but they don't allow you to pass a list of options; it's all done behind the scenes. + +```python + For more information, please visit the API reference [here](/API Reference/models/Discord/components/#naff.models.discord.components.Select). ## Responding diff --git a/docs/src/Guides/22 Live Patching.md b/docs/src/Guides/22 Live Patching.md new file mode 100644 index 000000000..b2ff67071 --- /dev/null +++ b/docs/src/Guides/22 Live Patching.md @@ -0,0 +1,19 @@ +# Live Patching + +NAFF has a few built-in extensions that add some features, primarily for debugging. One of these extensions that you can enable separately is to add [`jurigged`](https://github.com/breuleux/jurigged) for live patching of code. + +## How to enable + +```py +bot.load_extension("naff.ext.jurigged") +``` + +That's it! The extension will handle all of the leg work, and all you'll notice is that you have more messages in your logs (depending on the log level). + +## What is jurigged? + +`jurigged` is a library written to allow code hot reloading in Python. It allows you to edit code and have it automagically be updated in your program the next time it is run. The code under the hood is extremely complicated, but the interface to use it is relatively simple. + +## How is this useful? + +NAFF takes advantage of jurigged to reload any and all commands that were edited whenever a change is made, allowing you to have more uptime with while still adding/improving features of your bot. diff --git a/docs/src/Guides/22 Voice.md b/docs/src/Guides/23 Voice.md similarity index 100% rename from docs/src/Guides/22 Voice.md rename to docs/src/Guides/23 Voice.md diff --git a/docs/src/Guides/23 Localisation.md b/docs/src/Guides/24 Localisation.md similarity index 100% rename from docs/src/Guides/23 Localisation.md rename to docs/src/Guides/24 Localisation.md diff --git a/docs/src/Guides/24 Error Tracking.md b/docs/src/Guides/25 Error Tracking.md similarity index 100% rename from docs/src/Guides/24 Error Tracking.md rename to docs/src/Guides/25 Error Tracking.md diff --git a/docs/src/Guides/98 2.x Migration.md b/docs/src/Guides/98 2.x Migration.md new file mode 100644 index 000000000..d6a74758a --- /dev/null +++ b/docs/src/Guides/98 2.x Migration.md @@ -0,0 +1,110 @@ +# 1.x -> 2.x Migration Guide +2.x Was a rewrite of various parts of Naff, and as such, there are a few breaking changes. This guide will help you migrate your code from 1.x to 2.x. + +Please note; there are other additions to 2.x, but they are not breaking changes, and as such, are not covered in this guide. + +## Misc. +- All `edit` methods are now keyword arguments only. + - The exception is `content` on message edits, which is positional. +- `context.interaciton_id` is now an `int` instead of a `str`. + +## Selects +To simplify SelectMenus, NAFF made some changes to how SelectMenus are used. +- Options can now be and *reasonable* type, be it `SelectOption`, `dict`, `iterable` or `str` +- All parameters are now keyword only, excpet for `options` which remains positional or keyword +- `Select` was renamed to `StringSelectMenu` +- New select menus were implemented to support API changes + - https://discord.com/developers/docs/interactions/message-components#select-menus + - `UserSelectMenu` + - `RoleSelectMenu` + - `MentionableSelectMenu` + - `ChannelSelectMenu` + - `ChannelSelectMenu` + +### Before +```python +from naff import Select, SelectOption +await channel.send( + "Old SelectUX", + components=Select( + options=[ + SelectOption("test1", "test1"), + SelectOption("test2", "test2"), + SelectOption("test3", "test3"), + ], + placeholder="test", + ), + ) +``` + +### After +```python +from naff import StringSelectMenu + +await channel.send( + "New SelectMenu Menu UX test", components=StringSelectMenu(["test1", "test2", "test3"], placeholder="test") + ) +``` + +## Listeners +Listeners have received a series of ease-of-use updates for both extension and bot developers alike. + +- All internal listeners now have a `is_default_listener` attribute to make it easier to differentiate between the library's listeners and user defined listeners. +- `override_default_listeners` allows you to completely override the library's listeners with your own. + - Note it might be worth looking into processors if you're doing this; as they allow acting on the raw-payloads before they're processed by the library. +- All event objects now have a shortcut to listen to them via `BaseEvent.listen(coro, Client)` +- Listeners can now be delayed until the client is ready with a `delay_until_ready` argument. + +## Events +- All event objects now have a shortcut to listen to them via `BaseEvent.listen(coro, Client)` +- New Events! + - `ComponentCompletion` - Dispatched after the library ran any component callback. + - `AutocompleteCompletion` - Dispatched after the library ran any autocomplete callback. + - `ModalCompletion` - Dispatched after the library ran any modal callback. + - `Error` - Dispatched whenever the libray encounters an unhandled exception. + - Previously this was done by overriding the `on_error` method on the client, or in extensions + - `CommandError` - Dispatched whenever a command encounters an unhandled exception. + - `ComponentError` - Dispatched whenever a component encounters an unhandled exception. + - `AutocompleteError` - Dispatched whenever an autocomplete encounters an unhandled exception. + - `ModalError` - Dispatched whenever a modal encounters an unhandled exception. + - `NewThreadCreate` - Dispatched whenever a thread is newly created + - `GuildAvailable` - Dispatched whenever a guild becomes available + - note this requires the guild cache to be enabled + - `ApplicationCommandPermissionsUpdate` - Dispatched whenever a guild's application command permissions are updated + - `VoiceUserDeafen` - Dispatched whenever a user's deafen status changes + - `VoiceUserJoin` - Dispatched whenever a user joins a voice channel + - `VoiceUserLeave` - Dispatched whenever a user leaves a voice channel + - `VoiceUserMove` - Dispatched whenever a user moves to a different voice channel + - `VoiceUserMute` - Dispatched whenever a user's mute status changes +- Event Renamed + - `Button` has been renamed to `ButtonPressed` to avoid naming conflicts +- All events with a `context` attribute have had it renamed to `ctx` for consistency + +## Client Changes +- dm commands can now be disabled completely via the `disable_dm_commands` kwarg +- `Client.interaction_tree` offers a command tree of all application commands registered to the client +- The `Client` now sanity checks the cache configuration +- The `Client` will no longer warn you if a cache timeout occurs during startup + - These are caused by T&S shadow-deleting guilds, and are not a concern +- `async_startup_tasks` are now performed as soon as the client successfully connects to the REST API + - Note this is before the gateway connection is established, use a `on_ready` listener if you need to wait for the gateway connection +- Application Command syncing is now more error tolerant + +## Extensions +- Extensions no longer require a `setup` entrypoint function. + - For complex setups, I would still advise using an entrypoint function + +## Caching +- A `NullCache` object is now used to represent a disabled cache, to use it use `create_cache(0, 0, 0)` in a client kwarg as before + - This is a very niche-use-case and most people won't need to use it +- The `Client` will log a warning if `NullCache` is used for sanity checking +- The serializer now respects `no_export` metadata when using `as_dict` + +## Forums +- Forums now utilise the new API spec instead of the private-beta API +- A new `NewThreadCreate` event is now dispatched for brand new threads +- Add various helper methods to `Forum` objects +- `create_post` now handles `str`, `int` `Tag` objects for available tags + +## Emoji +- `PartialEmoji.from_str` can now return None if no emoji is found diff --git a/docs/src/Guides/index.md b/docs/src/Guides/index.md index 3bc65db61..5b0cc19d8 100644 --- a/docs/src/Guides/index.md +++ b/docs/src/Guides/index.md @@ -106,5 +106,11 @@ These guides are meant to help you get started with the library and offer a poin Whats different between NAFF and discord.py? +- [__:material-package-up: Migration from discord.py__](98 2.x Migration.md) + + --- + + How do I migrate from 1.x to 2.x? + diff --git a/naff/api/events/__init__.py b/naff/api/events/__init__.py index a5a2231b2..e7ed6614c 100644 --- a/naff/api/events/__init__.py +++ b/naff/api/events/__init__.py @@ -1,64 +1,89 @@ from . import processors from .discord import * from .internal import * +from .base import * __all__ = ( - "RawGatewayEvent", + "ApplicationCommandPermissionsUpdate", + "AutocompleteCompletion", + "AutocompleteError", + "AutoModCreated", + "AutoModDeleted", + "AutoModExec", + "AutoModUpdated", + "BanCreate", + "BanRemove", + "BaseEvent", + "BaseVoiceEvent", + "ButtonPressed", "ChannelCreate", - "ChannelUpdate", "ChannelDelete", "ChannelPinsUpdate", - "ThreadCreate", - "ThreadUpdate", - "ThreadDelete", - "ThreadListSync", - "ThreadMemberUpdate", - "ThreadMembersUpdate", + "ChannelUpdate", + "CommandCompletion", + "CommandError", + "Component", + "ComponentCompletion", + "ComponentError", + "Connect", + "Disconnect", + "Error", + "GuildAvailable", + "GuildEmojisUpdate", + "GuildEvent", "GuildJoin", - "GuildUpdate", "GuildLeft", - "GuildUnavailable", - "BanCreate", - "BanRemove", - "GuildEmojisUpdate", - "GuildStickersUpdate", - "MemberAdd", - "MemberRemove", - "MemberUpdate", - "RoleCreate", - "RoleUpdate", - "RoleDelete", "GuildMembersChunk", + "GuildStickersUpdate", + "GuildUnavailable", + "GuildUpdate", "IntegrationCreate", - "IntegrationUpdate", "IntegrationDelete", + "IntegrationUpdate", + "InteractionCreate", "InviteCreate", "InviteDelete", + "Login", + "MemberAdd", + "MemberRemove", + "MemberUpdate", "MessageCreate", - "MessageUpdate", "MessageDelete", "MessageDeleteBulk", "MessageReactionAdd", "MessageReactionRemove", "MessageReactionRemoveAll", + "MessageUpdate", + "ModalCompletion", + "ModalError", + "NewThreadCreate", "PresenceUpdate", + "RawGatewayEvent", + "Ready", + "Resume", + "RoleCreate", + "RoleDelete", + "RoleUpdate", + "Select", + "ShardConnect", + "ShardDisconnect", "StageInstanceCreate", "StageInstanceDelete", "StageInstanceUpdate", + "Startup", + "ThreadCreate", + "ThreadDelete", + "ThreadListSync", + "ThreadMembersUpdate", + "ThreadMemberUpdate", + "ThreadUpdate", "TypingStart", - "WebhooksUpdate", - "InteractionCreate", "VoiceStateUpdate", - "BaseEvent", - "GuildEvent", - "Login", - "Connect", - "Resume", - "Disconnect", - "Startup", - "Ready", + "VoiceUserDeafen", + "VoiceUserJoin", + "VoiceUserLeave", + "VoiceUserMove", + "VoiceUserMute", + "WebhooksUpdate", "WebsocketReady", - "Component", - "Button", - "Select", ) diff --git a/naff/api/events/base.py b/naff/api/events/base.py new file mode 100644 index 000000000..559d00f8f --- /dev/null +++ b/naff/api/events/base.py @@ -0,0 +1,85 @@ +import re +from typing import TYPE_CHECKING, Callable, Coroutine + +import attrs + +import naff.models as models +from naff.client.const import MISSING +from naff.client.utils.attr_utils import docs +from naff.models.discord.snowflake import to_snowflake + +if TYPE_CHECKING: + from naff import Client + from naff.models.discord.snowflake import Snowflake_Type + from naff.models.discord.guild import Guild + +__all__ = ("BaseEvent", "GuildEvent", "RawGatewayEvent") + +_event_reg = re.compile("(? str: + """The name of the event, defaults to the class name if not overridden.""" + name = self.override_name or self.__class__.__name__ + return _event_reg.sub("_", name).lower() + + @classmethod + def listen(cls, coro: Callable[..., Coroutine], client: "Client") -> "models.Listener": + """ + A shortcut for creating a listener for this event + + Args: + coro: The coroutine to call when the event is triggered. + client: The client instance to listen to. + + + ??? Hint "Example Usage:" + ```python + class SomeClass: + def __init__(self, bot: Client): + Ready.listen(self.some_func, bot) + + async def some_func(self, event): + print(f"{event.resolved_name} triggered") + ``` + Returns: + A listener object. + """ + listener = models.Listener.create(cls().resolved_name)(coro) + client.add_listener(listener) + return listener + + +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=False) +class GuildEvent(BaseEvent): + """A base event that adds guild_id.""" + + guild_id: "Snowflake_Type" = attrs.field(repr=False, metadata=docs("The ID of the guild"), converter=to_snowflake) + + @property + def guild(self) -> "Guild": + """Guild related to event""" + return self.bot.cache.get_guild(self.guild_id) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class RawGatewayEvent(BaseEvent): + """ + An event dispatched from the gateway. + + Holds the raw dict that the gateway dispatches + + """ + + data: dict = attrs.field(repr=False, factory=dict) + """Raw Data from the gateway""" diff --git a/naff/api/events/discord.py b/naff/api/events/discord.py index facb1014b..3b143f277 100644 --- a/naff/api/events/discord.py +++ b/naff/api/events/discord.py @@ -23,18 +23,22 @@ def on_guild_join(event): from typing import TYPE_CHECKING, List, Sequence, Union, Optional +import attrs + import naff.models -from naff.client.const import MISSING, Absent -from naff.client.utils.attr_utils import define, field, docs -from .internal import BaseEvent, GuildEvent +from naff.api.events.base import GuildEvent, BaseEvent +from naff.client.const import Absent +from naff.client.utils.attr_utils import docs __all__ = ( - "BanCreate", - "BanRemove", - "AutoModExec", + "ApplicationCommandPermissionsUpdate", "AutoModCreated", - "AutoModUpdated", "AutoModDeleted", + "AutoModExec", + "AutoModUpdated", + "BanCreate", + "BanRemove", + "BaseVoiceEvent", "ChannelCreate", "ChannelDelete", "ChannelPinsUpdate", @@ -44,6 +48,7 @@ def on_guild_join(event): "GuildLeft", "GuildMembersChunk", "GuildStickersUpdate", + "GuildAvailable", "GuildUnavailable", "GuildUpdate", "IntegrationCreate", @@ -62,9 +67,8 @@ def on_guild_join(event): "MessageReactionRemove", "MessageReactionRemoveAll", "MessageUpdate", - "ModalResponse", + "NewThreadCreate", "PresenceUpdate", - "RawGatewayEvent", "RoleCreate", "RoleDelete", "RoleUpdate", @@ -74,18 +78,23 @@ def on_guild_join(event): "ThreadCreate", "ThreadDelete", "ThreadListSync", - "ThreadMemberUpdate", "ThreadMembersUpdate", + "ThreadMemberUpdate", "ThreadUpdate", "TypingStart", "VoiceStateUpdate", + "VoiceUserDeafen", + "VoiceUserJoin", + "VoiceUserLeave", + "VoiceUserMove", + "VoiceUserMute", "WebhooksUpdate", ) if TYPE_CHECKING: from naff.models.discord.guild import Guild, GuildIntegration - from naff.models.discord.channel import BaseChannel, TYPE_THREAD_CHANNEL + from naff.models.discord.channel import BaseChannel, TYPE_THREAD_CHANNEL, VoiceChannel from naff.models.discord.message import Message from naff.models.discord.timestamp import Timestamp from naff.models.discord.user import Member, User, BaseUser @@ -96,114 +105,131 @@ def on_guild_join(event): from naff.models.discord.sticker import Sticker from naff.models.discord.voice_state import VoiceState from naff.models.discord.stage_instance import StageInstance - from naff.models.naff.context import ModalContext from naff.models.discord.auto_mod import AutoModerationAction, AutoModRule from naff.models.discord.reaction import Reaction + from naff.models.discord.app_perms import ApplicationCommandPermission -@define(kw_only=False) -class RawGatewayEvent(BaseEvent): - """ - An event dispatched from the gateway. - - Holds the raw dict that the gateway dispatches - - """ - - data: dict = field(factory=dict) - """Raw Data from the gateway""" - - -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class AutoModExec(BaseEvent): """Dispatched when an auto modation action is executed""" - execution: "AutoModerationAction" = field(metadata=docs("The executed auto mod action")) - channel: "BaseChannel" = field(metadata=docs("The channel the action was executed in")) - guild: "Guild" = field(metadata=docs("The guild the action was executed in")) + execution: "AutoModerationAction" = attrs.field(repr=False, metadata=docs("The executed auto mod action")) + channel: "BaseChannel" = attrs.field(repr=False, metadata=docs("The channel the action was executed in")) + guild: "Guild" = attrs.field(repr=False, metadata=docs("The guild the action was executed in")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class AutoModCreated(BaseEvent): - guild: "Guild" = field(metadata=docs("The guild the rule was modified in")) - rule: "AutoModRule" = field(metadata=docs("The rule that was modified")) + guild: "Guild" = attrs.field(repr=False, metadata=docs("The guild the rule was modified in")) + rule: "AutoModRule" = attrs.field(repr=False, metadata=docs("The rule that was modified")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class AutoModUpdated(AutoModCreated): """Dispatched when an auto mod rule is modified""" ... -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class AutoModDeleted(AutoModCreated): """Dispatched when an auto mod rule is deleted""" ... -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class ApplicationCommandPermissionsUpdate(BaseEvent): + guild_id: "Snowflake_Type" = attrs.field( + repr=False, metadata=docs("The guild the command permissions were updated in") + ) + application_id: "Snowflake_Type" = attrs.field( + repr=False, metadata=docs("The application the command permissions were updated for") + ) + permissions: List["ApplicationCommandPermission"] = attrs.field( + repr=False, factory=list, metadata=docs("The updated permissions") + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ChannelCreate(BaseEvent): """Dispatched when a channel is created.""" - channel: "BaseChannel" = field(metadata=docs("The channel this event is dispatched from")) + channel: "BaseChannel" = attrs.field(repr=False, metadata=docs("The channel this event is dispatched from")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ChannelUpdate(BaseEvent): """Dispatched when a channel is updated.""" - before: "BaseChannel" = field() + before: "BaseChannel" = attrs.field( + repr=False, + ) """Channel before this event. MISSING if it was not cached before""" - after: "BaseChannel" = field() + after: "BaseChannel" = attrs.field( + repr=False, + ) """Channel after this event""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ChannelDelete(ChannelCreate): """Dispatched when a channel is deleted.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ChannelPinsUpdate(ChannelCreate): """Dispatched when a channel's pins are updated.""" - last_pin_timestamp: "Timestamp" = field() + last_pin_timestamp: "Timestamp" = attrs.field( + repr=False, + ) """The time at which the most recent pinned message was pinned""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadCreate(BaseEvent): - """Dispatched when a thread is created.""" + """Dispatched when a thread is created, or a thread is new to the client""" + + thread: "TYPE_THREAD_CHANNEL" = attrs.field(repr=False, metadata=docs("The thread this event is dispatched from")) - thread: "TYPE_THREAD_CHANNEL" = field(metadata=docs("The thread this event is dispatched from")) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class NewThreadCreate(ThreadCreate): + """Dispatched when a thread is newly created.""" -@define(kw_only=False) + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadUpdate(ThreadCreate): """Dispatched when a thread is updated.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadDelete(ThreadCreate): """Dispatched when a thread is deleted.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadListSync(BaseEvent): """Dispatched when gaining access to a channel, contains all active threads in that channel.""" - channel_ids: Sequence["Snowflake_Type"] = field() + channel_ids: Sequence["Snowflake_Type"] = attrs.field( + repr=False, + ) """The parent channel ids whose threads are being synced. If omitted, then threads were synced for the entire guild. This array may contain channel_ids that have no active threads as well, so you know to clear that data.""" - threads: List["BaseChannel"] = field() + threads: List["BaseChannel"] = attrs.field( + repr=False, + ) """all active threads in the given channels that the current user can access""" - members: List["Member"] = field() + members: List["Member"] = attrs.field( + repr=False, + ) """all thread member objects from the synced threads for the current user, indicating which threads the current user has been added to""" # todo implementation missing -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadMemberUpdate(ThreadCreate): """ Dispatched when the thread member object for the current user is updated. @@ -214,26 +240,30 @@ class ThreadMemberUpdate(ThreadCreate): """ - member: "Member" = field() + member: "Member" = attrs.field( + repr=False, + ) """The member who was added""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ThreadMembersUpdate(BaseEvent): """Dispatched when anyone is added or removed from a thread.""" - id: "Snowflake_Type" = field() + id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The ID of the thread""" - member_count: int = field(default=50) + member_count: int = attrs.field(repr=False, default=50) """the approximate number of members in the thread, capped at 50""" - added_members: List["Member"] = field(factory=list) + added_members: List["Member"] = attrs.field(repr=False, factory=list) """Users added to the thread""" - removed_member_ids: List["Snowflake_Type"] = field(factory=list) + removed_member_ids: List["Snowflake_Type"] = attrs.field(repr=False, factory=list) """Users removed from the thread""" -@define(kw_only=False) -class GuildJoin(BaseEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildJoin(GuildEvent): """ Dispatched when a guild is joined, created, or becomes available. @@ -242,122 +272,139 @@ class GuildJoin(BaseEvent): """ - guild: "Guild" = field() - """The guild that was created""" - -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class GuildUpdate(BaseEvent): """Dispatched when a guild is updated.""" - before: "Guild" = field() + before: "Guild" = attrs.field( + repr=False, + ) """Guild before this event""" - after: "Guild" = field() + after: "Guild" = attrs.field( + repr=False, + ) """Guild after this event""" -@define(kw_only=False) -class GuildLeft(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildLeft(BaseEvent): """Dispatched when a guild is left.""" - guild: Optional["Guild"] = field(default=MISSING) - """The guild, if it was cached""" + guild: "Guild" = attrs.field(repr=True) + """The guild this event is dispatched from""" -@define(kw_only=False) -class GuildUnavailable(BaseEvent, GuildEvent): - """Dispatched when a guild is not available.""" +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildAvailable(GuildEvent): + """Dispatched when a guild becomes available.""" - guild: Optional["Guild"] = field(default=MISSING) - """The guild, if it was cached""" +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildUnavailable(GuildEvent): + """Dispatched when a guild is not available.""" -@define(kw_only=False) -class BanCreate(BaseEvent, GuildEvent): + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class BanCreate(GuildEvent): """Dispatched when someone was banned from a guild.""" - user: "BaseUser" = field(metadata=docs("The user")) + user: "BaseUser" = attrs.field(repr=False, metadata=docs("The user")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class BanRemove(BanCreate): """Dispatched when a users ban is removed.""" -@define(kw_only=False) -class GuildEmojisUpdate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildEmojisUpdate(GuildEvent): """Dispatched when a guild's emojis are updated.""" - before: List["CustomEmoji"] = field(factory=list) + before: List["CustomEmoji"] = attrs.field(repr=False, factory=list) """List of emoji before this event. Only includes emojis that were cached. To enable the emoji cache (and this field), start your bot with `Client(enable_emoji_cache=True)`""" - after: List["CustomEmoji"] = field(factory=list) + after: List["CustomEmoji"] = attrs.field(repr=False, factory=list) """List of emoji after this event""" -@define(kw_only=False) -class GuildStickersUpdate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildStickersUpdate(GuildEvent): """Dispatched when a guild's stickers are updated.""" - stickers: List["Sticker"] = field(factory=list) + stickers: List["Sticker"] = attrs.field(repr=False, factory=list) """List of stickers from after this event""" -@define(kw_only=False) -class MemberAdd(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MemberAdd(GuildEvent): """Dispatched when a member is added to a guild.""" - member: "Member" = field(metadata=docs("The member who was added")) + member: "Member" = attrs.field(repr=False, metadata=docs("The member who was added")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MemberRemove(MemberAdd): """Dispatched when a member is removed from a guild.""" - member: Union["Member", "User"] = field( - metadata=docs("The member who was added, can be user if the member is not cached") + member: Union["Member", "User"] = attrs.field( + repr=False, metadata=docs("The member who was added, can be user if the member is not cached") ) -@define(kw_only=False) -class MemberUpdate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MemberUpdate(GuildEvent): """Dispatched when a member is updated.""" - before: "Member" = field() + before: "Member" = attrs.field( + repr=False, + ) """The state of the member before this event""" - after: "Member" = field() + after: "Member" = attrs.field( + repr=False, + ) """The state of the member after this event""" -@define(kw_only=False) -class RoleCreate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class RoleCreate(GuildEvent): """Dispatched when a role is created.""" - role: "Role" = field() + role: "Role" = attrs.field( + repr=False, + ) """The created role""" -@define(kw_only=False) -class RoleUpdate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class RoleUpdate(GuildEvent): """Dispatched when a role is updated.""" - before: Absent["Role"] = field() + before: Absent["Role"] = attrs.field( + repr=False, + ) """The role before this event""" - after: "Role" = field() + after: "Role" = attrs.field( + repr=False, + ) """The role after this event""" -@define(kw_only=False) -class RoleDelete(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class RoleDelete(GuildEvent): """Dispatched when a guild role is deleted.""" - id: "Snowflake_Type" = field() + id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The ID of the deleted role""" - role: Absent["Role"] = field() + role: Absent["Role"] = attrs.field( + repr=False, + ) """The deleted role""" -@define(kw_only=False) -class GuildMembersChunk(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class GuildMembersChunk(GuildEvent): """ Sent in response to Guild Request Members. @@ -366,96 +413,120 @@ class GuildMembersChunk(BaseEvent, GuildEvent): """ - chunk_index: int = field() + chunk_index: int = attrs.field( + repr=False, + ) """The chunk index in the expected chunks for this response (0 <= chunk_index < chunk_count)""" - chunk_count: int = field() + chunk_count: int = attrs.field( + repr=False, + ) """the total number of expected chunks for this response""" - presences: List = field() + presences: List = attrs.field( + repr=False, + ) """if passing true to `REQUEST_GUILD_MEMBERS`, presences of the returned members will be here""" - nonce: str = field() + nonce: str = attrs.field( + repr=False, + ) """The nonce used in the request, if any""" - members: List["Member"] = field(factory=list) + members: List["Member"] = attrs.field(repr=False, factory=list) """A list of members""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class IntegrationCreate(BaseEvent): """Dispatched when a guild integration is created.""" - integration: "GuildIntegration" = field() + integration: "GuildIntegration" = attrs.field( + repr=False, + ) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class IntegrationUpdate(IntegrationCreate): """Dispatched when a guild integration is updated.""" -@define(kw_only=False) -class IntegrationDelete(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class IntegrationDelete(GuildEvent): """Dispatched when a guild integration is deleted.""" - id: "Snowflake_Type" = field() + id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The ID of the integration""" - application_id: "Snowflake_Type" = field(default=None) + application_id: "Snowflake_Type" = attrs.field(repr=False, default=None) """The ID of the bot/application for this integration""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class InviteCreate(BaseEvent): """Dispatched when a guild invite is created.""" - invite: naff.models.Invite = field() + invite: naff.models.Invite = attrs.field( + repr=False, + ) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class InviteDelete(InviteCreate): """Dispatched when an invite is deleted.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MessageCreate(BaseEvent): """Dispatched when a message is created.""" - message: "Message" = field() + message: "Message" = attrs.field( + repr=False, + ) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MessageUpdate(BaseEvent): """Dispatched when a message is edited.""" - before: "Message" = field() + before: "Message" = attrs.field( + repr=False, + ) """The message before this event was created""" - after: "Message" = field() + after: "Message" = attrs.field( + repr=False, + ) """The message after this event was created""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MessageDelete(BaseEvent): """Dispatched when a message is deleted.""" - message: "Message" = field() + message: "Message" = attrs.field( + repr=False, + ) -@define(kw_only=False) -class MessageDeleteBulk(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessageDeleteBulk(GuildEvent): """Dispatched when multiple messages are deleted at once.""" - channel_id: "Snowflake_Type" = field() + channel_id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The ID of the channel these were deleted in""" - ids: List["Snowflake_Type"] = field(factory=list) + ids: List["Snowflake_Type"] = attrs.field(repr=False, factory=list) """A list of message snowflakes""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MessageReactionAdd(BaseEvent): """Dispatched when a reaction is added to a message.""" - message: "Message" = field(metadata=docs("The message that was reacted to")) - emoji: "PartialEmoji" = field(metadata=docs("The emoji that was added to the message")) - author: Union["Member", "User"] = field(metadata=docs("The user who added the reaction")) + message: "Message" = attrs.field(repr=False, metadata=docs("The message that was reacted to")) + emoji: "PartialEmoji" = attrs.field(repr=False, metadata=docs("The emoji that was added to the message")) + author: Union["Member", "User"] = attrs.field(repr=False, metadata=docs("The user who added the reaction")) # reaction can be None when the message is not in the cache, and it was the last reaction, and it was deleted in the event - reaction: Optional["Reaction"] = field( - default=None, metadata=docs("The reaction object corresponding to the emoji") + reaction: Optional["Reaction"] = attrs.field( + repr=False, default=None, metadata=docs("The reaction object corresponding to the emoji") ) @property @@ -466,95 +537,205 @@ def reaction_count(self) -> int: return self.reaction.count -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class MessageReactionRemove(MessageReactionAdd): """Dispatched when a reaction is removed.""" -@define(kw_only=False) -class MessageReactionRemoveAll(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessageReactionRemoveAll(GuildEvent): """Dispatched when all reactions are removed from a message.""" - message: "Message" = field() + message: "Message" = attrs.field( + repr=False, + ) """The message that was reacted to""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class PresenceUpdate(BaseEvent): """A user's presence has changed.""" - user: "User" = field() + user: "User" = attrs.field( + repr=False, + ) """The user in question""" - status: str = field() + status: str = attrs.field( + repr=False, + ) """'Either `idle`, `dnd`, `online`, or `offline`'""" - activities: List["Activity"] = field() + activities: List["Activity"] = attrs.field( + repr=False, + ) """The users current activities""" - client_status: dict = field() + client_status: dict = attrs.field( + repr=False, + ) """What platform the user is reported as being on""" - guild_id: "Snowflake_Type" = field() + guild_id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The guild this presence update was dispatched from""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class StageInstanceCreate(BaseEvent): """Dispatched when a stage instance is created.""" - stage_instance: "StageInstance" = field(metadata=docs("The stage instance")) + stage_instance: "StageInstance" = attrs.field(repr=False, metadata=docs("The stage instance")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class StageInstanceDelete(StageInstanceCreate): """Dispatched when a stage instance is deleted.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class StageInstanceUpdate(StageInstanceCreate): """Dispatched when a stage instance is updated.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class TypingStart(BaseEvent): """Dispatched when a user starts typing.""" - author: Union["User", "Member"] = field() + author: Union["User", "Member"] = attrs.field( + repr=False, + ) """The user who started typing""" - channel: "BaseChannel" = field() + channel: "BaseChannel" = attrs.field( + repr=False, + ) """The channel typing is in""" - guild: "Guild" = field() + guild: "Guild" = attrs.field( + repr=False, + ) """The ID of the guild this typing is in""" - timestamp: "Timestamp" = field() + timestamp: "Timestamp" = attrs.field( + repr=False, + ) """unix time (in seconds) of when the user started typing""" -@define(kw_only=False) -class WebhooksUpdate(BaseEvent, GuildEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class WebhooksUpdate(GuildEvent): """Dispatched when a guild channel webhook is created, updated, or deleted.""" # Discord doesnt sent the webhook object for this event, for some reason - channel_id: "Snowflake_Type" = field() + channel_id: "Snowflake_Type" = attrs.field( + repr=False, + ) """The ID of the webhook was updated""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class InteractionCreate(BaseEvent): """Dispatched when a user uses an Application Command.""" - interaction: dict = field() - - -@define(kw_only=False) -class ModalResponse(BaseEvent): - """Dispatched when a modal receives a response""" - - context: "ModalContext" = field() - """The context data of the modal""" + interaction: dict = attrs.field( + repr=False, + ) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class VoiceStateUpdate(BaseEvent): - """Dispatched when a user joins/leaves/moves voice channels.""" + """Dispatched when a user's voice state changes.""" - before: Optional["VoiceState"] = field() + before: Optional["VoiceState"] = attrs.field( + repr=False, + ) """The voice state before this event was created or None if the user was not in a voice channel""" - after: Optional["VoiceState"] = field() + after: Optional["VoiceState"] = attrs.field( + repr=False, + ) """The voice state after this event was created or None if the user is no longer in a voice channel""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class BaseVoiceEvent(BaseEvent): + state: "VoiceState" = attrs.field( + repr=False, + ) + """The current voice state of the user""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class VoiceUserMove(BaseVoiceEvent): + """Dispatched when a user moves voice channels.""" + + author: Union["User", "Member"] = attrs.field( + repr=False, + ) + + previous_channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The previous voice channel the user was in""" + new_channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The new voice channel the user is in""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class VoiceUserMute(BaseVoiceEvent): + """Dispatched when a user is muted or unmuted.""" + + author: Union["User", "Member"] = attrs.field( + repr=False, + ) + """The user who was muted or unmuted""" + channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The voice channel the user was muted or unmuted in""" + mute: bool = attrs.field( + repr=False, + ) + """The new mute state of the user""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class VoiceUserDeafen(BaseVoiceEvent): + """Dispatched when a user is deafened or undeafened.""" + + author: Union["User", "Member"] = attrs.field( + repr=False, + ) + """The user who was deafened or undeafened""" + channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The voice channel the user was deafened or undeafened in""" + deaf: bool = attrs.field( + repr=False, + ) + """The new deaf state of the user""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class VoiceUserJoin(BaseVoiceEvent): + """Dispatched when a user joins a voice channel.""" + + author: Union["User", "Member"] = attrs.field( + repr=False, + ) + """The user who joined the voice channel""" + channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The voice channel the user joined""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class VoiceUserLeave(BaseVoiceEvent): + """Dispatched when a user leaves a voice channel.""" + + author: Union["User", "Member"] = attrs.field( + repr=False, + ) + """The user who left the voice channel""" + channel: "VoiceChannel" = attrs.field( + repr=False, + ) + """The voice channel the user left""" diff --git a/naff/api/events/internal.py b/naff/api/events/internal.py index 687b74e98..5ce13ea89 100644 --- a/naff/api/events/internal.py +++ b/naff/api/events/internal.py @@ -21,130 +21,87 @@ def on_guild_join(event): """ import re -from typing import TYPE_CHECKING, Any, Optional, Callable, Coroutine +from typing import Any, Optional, TYPE_CHECKING -from naff.client.const import MISSING -from naff.models.discord.snowflake import to_snowflake -from naff.client.utils.attr_utils import define, field, docs -import naff.models as models +import attrs + +from naff.api.events.base import BaseEvent, RawGatewayEvent +from naff.client.utils.attr_utils import docs __all__ = ( - "BaseEvent", - "Button", + "ButtonPressed", "Component", "Connect", "Disconnect", "Error", "ShardConnect", "ShardDisconnect", - "GuildEvent", "Login", "Ready", "Resume", "Select", "Startup", "WebsocketReady", + "CommandError", + "ComponentError", + "AutocompleteError", + "ModalError", + "CommandCompletion", + "ComponentCompletion", + "AutocompleteCompletion", + "ModalCompletion", ) if TYPE_CHECKING: - from naff import Client - from naff.models.naff.context import ComponentContext, Context - from naff.models.discord.snowflake import Snowflake_Type - from naff.models.discord.guild import Guild + from naff.models.naff.context import ( + ComponentContext, + Context, + AutocompleteContext, + ModalContext, + InteractionContext, + PrefixedContext, + HybridContext, + ) _event_reg = re.compile("(? str: - """The name of the event, defaults to the class name if not overridden.""" - name = self.override_name or self.__class__.__name__ - return _event_reg.sub("_", name).lower() - - @classmethod - def listen(cls, coro: Callable[..., Coroutine], client: "Client") -> "models.Listener": - """ - A shortcut for creating a listener for this event - - Args: - coro: The coroutine to call when the event is triggered. - client: The client instance to listen to. - - - ??? Hint "Example Usage:" - ```python - class SomeClass: - def __init__(self, bot: Client): - Ready.listen(self.some_func, bot) - - async def some_func(self, event): - print(f"{event.resolved_name} triggered") - ``` - Returns: - A listener object. - """ - listener = models.Listener.create(cls().resolved_name)(coro) - client.add_listener(listener) - return listener - - -@define(slots=False, kw_only=False) -class GuildEvent: - """A base event that adds guild_id.""" - - guild_id: "Snowflake_Type" = field(metadata=docs("The ID of the guild"), converter=to_snowflake) - - @property - def guild(self) -> "Guild": - """Guild related to event""" - return self.bot.cache.get_guild(self.guild_id) - - -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Login(BaseEvent): """The bot has just logged in.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Connect(BaseEvent): """The bot is now connected to the discord Gateway.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Resume(BaseEvent): """The bot has resumed its connection to the discord Gateway.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Disconnect(BaseEvent): """The bot has just disconnected.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ShardConnect(Connect): """A shard just connected to the discord Gateway.""" - shard_id: int = field(metadata=docs("The ID of the shard")) + shard_id: int = attrs.field(repr=False, metadata=docs("The ID of the shard")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ShardDisconnect(Disconnect): """A shard just disconnected.""" - shard_id: int = field(metadata=docs("The ID of the shard")) + shard_id: int = attrs.field(repr=False, metadata=docs("The ID of the shard")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Startup(BaseEvent): """ The client is now ready for the first time. @@ -155,7 +112,7 @@ class Startup(BaseEvent): """ -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Ready(BaseEvent): """ The client is now ready. @@ -167,36 +124,100 @@ class Ready(BaseEvent): """ -@define(kw_only=False) -class WebsocketReady(BaseEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class WebsocketReady(RawGatewayEvent): """The gateway has reported that it is ready.""" - data: dict = field(metadata=docs("The data from the ready event")) + data: dict = attrs.field(repr=False, metadata=docs("The data from the ready event")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Component(BaseEvent): """Dispatched when a user uses a Component.""" - context: "ComponentContext" = field(metadata=docs("The context of the interaction")) + ctx: "ComponentContext" = attrs.field(repr=False, metadata=docs("The context of the interaction")) -@define(kw_only=False) -class Button(Component): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class ButtonPressed(Component): """Dispatched when a user uses a Button.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Select(Component): """Dispatched when a user uses a Select.""" -@define(kw_only=False) -class Error(BaseEvent): +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class CommandCompletion(BaseEvent): + """Dispatched after the library ran any command callback.""" + + ctx: "InteractionContext | PrefixedContext | HybridContext" = attrs.field( + repr=False, metadata=docs("The command context") + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class ComponentCompletion(BaseEvent): + """Dispatched after the library ran any component callback.""" + + ctx: "ComponentContext" = attrs.field(repr=False, metadata=docs("The component context")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class AutocompleteCompletion(BaseEvent): + """Dispatched after the library ran any autocomplete callback.""" + + ctx: "AutocompleteContext" = attrs.field(repr=False, metadata=docs("The autocomplete context")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class ModalCompletion(BaseEvent): + """Dispatched after the library ran any modal callback.""" + + ctx: "ModalContext" = attrs.field(repr=False, metadata=docs("The modal context")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class _Error(BaseEvent): + error: Exception = attrs.field(repr=False, metadata=docs("The error that was encountered")) + args: tuple[Any] = attrs.field(repr=False, factory=tuple) + kwargs: dict[str, Any] = attrs.field(repr=False, factory=dict) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class Error(_Error): """Dispatched when the library encounters an error.""" - source: str = field(metadata=docs("The source of the error")) - error: Exception = field(metadata=docs("The error that was encountered")) - args: tuple[Any] = field(factory=tuple) - kwargs: dict[str, Any] = field(factory=dict) - ctx: Optional["Context"] = field(default=None, metadata=docs("The Context, if one was active")) + source: str = attrs.field(repr=False, metadata=docs("The source of the error")) + ctx: Optional["Context"] = attrs.field(repr=False, default=None, metadata=docs("The Context, if one was active")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class CommandError(_Error): + """Dispatched when the library encounters an error in a command.""" + + ctx: "InteractionContext | PrefixedContext | HybridContext" = attrs.field( + repr=False, metadata=docs("The command context") + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class ComponentError(_Error): + """Dispatched when the library encounters an error in a component.""" + + ctx: "ComponentContext" = attrs.field(repr=False, metadata=docs("The component context")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class AutocompleteError(_Error): + """Dispatched when the library encounters an error in an autocomplete.""" + + ctx: "AutocompleteContext" = attrs.field(repr=False, metadata=docs("The autocomplete context")) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class ModalError(_Error): + """Dispatched when the library encounters an error in a modal.""" + + ctx: "ModalContext" = attrs.field(repr=False, metadata=docs("The modal context")) diff --git a/naff/api/events/processors/__init__.py b/naff/api/events/processors/__init__.py index 2c3b215bd..25089671e 100644 --- a/naff/api/events/processors/__init__.py +++ b/naff/api/events/processors/__init__.py @@ -1,5 +1,6 @@ from .channel_events import * from .guild_events import * +from .integrations import * from .member_events import * from .message_events import * from .reaction_events import * diff --git a/naff/api/events/processors/_template.py b/naff/api/events/processors/_template.py index c10fd33a3..5f258f63c 100644 --- a/naff/api/events/processors/_template.py +++ b/naff/api/events/processors/_template.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Callable, Coroutine from naff.client.const import Absent, MISSING - from naff.models.discord.user import NaffUser if TYPE_CHECKING: diff --git a/naff/api/events/processors/channel_events.py b/naff/api/events/processors/channel_events.py index 2ba3bd7a6..8086334d1 100644 --- a/naff/api/events/processors/channel_events.py +++ b/naff/api/events/processors/channel_events.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING import naff.api.events as events +from naff.client.const import MISSING from naff.models.discord.channel import BaseChannel from naff.models.discord.invite import Invite -from naff.client.const import MISSING from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: diff --git a/naff/api/events/processors/guild_events.py b/naff/api/events/processors/guild_events.py index e397b6d9b..c1fdb6ccf 100644 --- a/naff/api/events/processors/guild_events.py +++ b/naff/api/events/processors/guild_events.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from naff.client.const import MISSING -from ._template import EventMixinTemplate, Processor -from naff.models import GuildIntegration, Sticker, to_snowflake from naff.api.events.discord import ( GuildEmojisUpdate, IntegrationCreate, @@ -16,6 +12,9 @@ GuildStickersUpdate, WebhooksUpdate, ) +from naff.client.const import MISSING +from naff.models import GuildIntegration, Sticker, to_snowflake +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent @@ -33,6 +32,11 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None: event: raw guild create event """ + new_guild: bool = True + if self.cache.get_guild(event.data["id"]): + # guild already cached, most likely an unavailable guild coming back online + new_guild = False + guild = self.cache.place_guild_data(event.data) self._user._guild_ids.add(to_snowflake(event.data.get("id"))) # noqa : w0212 @@ -43,7 +47,10 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None: # delays events until chunking has completed await guild.chunk() - self.dispatch(events.GuildJoin(guild)) + if new_guild: + self.dispatch(events.GuildJoin(guild.id)) + else: + self.dispatch(events.GuildAvailable(guild.id)) @Processor.define() async def _on_raw_guild_update(self, event: "RawGatewayEvent") -> None: @@ -54,12 +61,7 @@ async def _on_raw_guild_update(self, event: "RawGatewayEvent") -> None: async def _on_raw_guild_delete(self, event: "RawGatewayEvent") -> None: guild_id = int(event.data.get("id")) if event.data.get("unavailable", False): - self.dispatch( - events.GuildUnavailable( - guild_id, - self.cache.get_guild(guild_id) or MISSING, - ) - ) + self.dispatch(events.GuildUnavailable(guild_id)) else: # noinspection PyProtectedMember if guild_id in self._user._guild_ids: @@ -70,12 +72,7 @@ async def _on_raw_guild_delete(self, event: "RawGatewayEvent") -> None: guild = self.cache.get_guild(guild_id) self.cache.delete_guild(guild_id) - self.dispatch( - events.GuildLeft( - guild_id, - guild or MISSING, - ) - ) + self.dispatch(events.GuildLeft(guild)) @Processor.define() async def _on_raw_guild_ban_add(self, event: "RawGatewayEvent") -> None: diff --git a/naff/api/events/processors/integrations.py b/naff/api/events/processors/integrations.py new file mode 100644 index 000000000..e8eaf84c1 --- /dev/null +++ b/naff/api/events/processors/integrations.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING + +from naff.models.discord.app_perms import ApplicationCommandPermission, CommandPermissions +from naff.models.discord.snowflake import to_snowflake +from ._template import EventMixinTemplate, Processor +from ... import events + +if TYPE_CHECKING: + from naff.api.events import RawGatewayEvent + +__all__ = ("IntegrationEvents",) + + +class IntegrationEvents(EventMixinTemplate): + @Processor.define() + async def _raw_application_command_permissions_update(self, event: "RawGatewayEvent") -> None: + perms = [ApplicationCommandPermission.from_dict(perm, self) for perm in event.data["permissions"]] + guild_id = to_snowflake(event.data["guild_id"]) + command_id = to_snowflake(event.data["id"]) + + if guild := self.get_guild(guild_id): + if guild.permissions: + if command_id not in guild.command_permissions: + guild.command_permissions[command_id] = CommandPermissions( + client=self, command_id=command_id, guild=guild + ) + + command_permissions = guild.command_permissions[command_id] + command_permissions.update_permissions(*perms) + + self.dispatch(events.ApplicationCommandPermissionsUpdate(guild, perms)) diff --git a/naff/api/events/processors/member_events.py b/naff/api/events/processors/member_events.py index d5a2e9bb4..80fb771b4 100644 --- a/naff/api/events/processors/member_events.py +++ b/naff/api/events/processors/member_events.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING import naff.api.events as events - from naff.client.const import MISSING from ._template import EventMixinTemplate, Processor diff --git a/naff/api/events/processors/message_events.py b/naff/api/events/processors/message_events.py index e07566812..a3c23b994 100644 --- a/naff/api/events/processors/message_events.py +++ b/naff/api/events/processors/message_events.py @@ -2,10 +2,8 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from naff.client.const import logger -from ._template import EventMixinTemplate, Processor from naff.models import to_snowflake, BaseMessage +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent @@ -55,7 +53,7 @@ async def _on_raw_message_delete(self, event: "RawGatewayEvent") -> None: if not message: message = BaseMessage.from_dict(event.data, self) self.cache.delete_message(event.data["channel_id"], event.data["id"]) - logger.debug(f"Dispatching Event: {event.resolved_name}") + self.logger.debug(f"Dispatching Event: {event.resolved_name}") self.dispatch(events.MessageDelete(message)) @Processor.define() diff --git a/naff/api/events/processors/reaction_events.py b/naff/api/events/processors/reaction_events.py index b4997fa05..67bc688c9 100644 --- a/naff/api/events/processors/reaction_events.py +++ b/naff/api/events/processors/reaction_events.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from ._template import EventMixinTemplate, Processor from naff.models import PartialEmoji, Reaction +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent diff --git a/naff/api/events/processors/role_events.py b/naff/api/events/processors/role_events.py index ca0de2ef8..7d66b7bdf 100644 --- a/naff/api/events/processors/role_events.py +++ b/naff/api/events/processors/role_events.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING import naff.api.events as events - from naff.client.const import MISSING from ._template import EventMixinTemplate, Processor diff --git a/naff/api/events/processors/stage_events.py b/naff/api/events/processors/stage_events.py index efdf91b50..2109fe6bc 100644 --- a/naff/api/events/processors/stage_events.py +++ b/naff/api/events/processors/stage_events.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from ._template import EventMixinTemplate, Processor from naff.models import StageInstance +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent diff --git a/naff/api/events/processors/thread_events.py b/naff/api/events/processors/thread_events.py index 5b4886111..232aea8cf 100644 --- a/naff/api/events/processors/thread_events.py +++ b/naff/api/events/processors/thread_events.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from ._template import EventMixinTemplate, Processor from naff.models import to_snowflake +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent @@ -14,7 +13,10 @@ class ThreadEvents(EventMixinTemplate): @Processor.define() async def _on_raw_thread_create(self, event: "RawGatewayEvent") -> None: - self.dispatch(events.ThreadCreate(self.cache.place_channel_data(event.data))) + thread = self.cache.place_channel_data(event.data) + if event.data.get("newly_created"): + self.dispatch(events.NewThreadCreate(thread)) + self.dispatch(events.ThreadCreate(thread)) @Processor.define() async def _on_raw_thread_update(self, event: "RawGatewayEvent") -> None: diff --git a/naff/api/events/processors/user_events.py b/naff/api/events/processors/user_events.py index 2f7e80d90..c414c4bd2 100644 --- a/naff/api/events/processors/user_events.py +++ b/naff/api/events/processors/user_events.py @@ -1,9 +1,9 @@ from typing import Union, TYPE_CHECKING -import naff.api.events as events -from ._template import EventMixinTemplate, Processor +import naff.api.events as events from naff.models import User, Member, BaseChannel, Timestamp, to_snowflake, Activity from naff.models.discord.enums import Status +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent diff --git a/naff/api/events/processors/voice_events.py b/naff/api/events/processors/voice_events.py index 9e237741e..c1e8ac2a8 100644 --- a/naff/api/events/processors/voice_events.py +++ b/naff/api/events/processors/voice_events.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING import naff.api.events as events - from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: @@ -24,6 +23,18 @@ async def _on_raw_voice_state_update(self, event: "RawGatewayEvent") -> None: # noinspection PyProtectedMember await vc._voice_state_update(before, after, event.data) + if before and after: + if (before.mute != after.mute) or (before.self_mute != after.self_mute): + self.dispatch(events.VoiceUserMute(after, after.member, after.channel, after.mute or after.self_mute)) + if (before.deaf != after.deaf) or (before.self_deaf != after.self_deaf): + self.dispatch(events.VoiceUserDeafen(after, after.member, after.channel, after.deaf or after.self_deaf)) + if before.channel != after.channel: + self.dispatch(events.VoiceUserMove(after, after.member, before.channel, after.channel)) + elif not before and after: + self.dispatch(events.VoiceUserJoin(after, after.member, after.channel)) + elif before and not after: + self.dispatch(events.VoiceUserLeave(before, before.member, before.channel)) + @Processor.define() async def _on_raw_voice_server_update(self, event: "RawGatewayEvent") -> None: if vc := self.cache.get_bot_voice_state(event.data["guild_id"]): diff --git a/naff/api/gateway/gateway.py b/naff/api/gateway/gateway.py index 149cb4759..ef4c7b060 100644 --- a/naff/api/gateway/gateway.py +++ b/naff/api/gateway/gateway.py @@ -7,7 +7,7 @@ from typing import TypeVar, TYPE_CHECKING from naff.api import events -from naff.client.const import logger, MISSING, __api_version__ +from naff.client.const import MISSING, __api_version__ from naff.client.utils.input_utils import OverriddenJson from naff.client.utils.serializer import dict_filter_none from naff.models.discord.enums import Status @@ -176,31 +176,31 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None: match op: case OPCODE.HEARTBEAT: - logger.debug("Received heartbeat request from gateway") + self.logger.debug("Received heartbeat request from gateway") return await self.send_heartbeat() case OPCODE.HEARTBEAT_ACK: self.latency.append(time.perf_counter() - self._last_heartbeat) if self._last_heartbeat != 0 and self.latency[-1] >= 15: - logger.warning( + self.logger.warning( f"High Latency! shard ID {self.shard[0]} heartbeat took {self.latency[-1]:.1f}s to be acknowledged!" ) else: - logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds") + self.logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds") return self._acknowledged.set() case OPCODE.RECONNECT: - logger.debug("Gateway requested reconnect. Reconnecting...") + self.logger.debug("Gateway requested reconnect. Reconnecting...") return await self.reconnect(resume=True, url=self.ws_resume_url) case OPCODE.INVALIDATE_SESSION: - logger.warning("Gateway has invalidated session! Reconnecting...") + self.logger.warning("Gateway has invalidated session! Reconnecting...") return await self.reconnect() case _: - return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}") + return self.logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}") async def dispatch_event(self, data, seq, event) -> None: match event: @@ -212,13 +212,12 @@ async def dispatch_event(self, data, seq, event) -> None: self.ws_resume_url = ( f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream" ) - logger.info(f"Shard {self.shard[0]} has connected to gateway!") - logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}") - # todo: future polls, improve guild caching here. run the debugger. you'll see why + self.logger.info(f"Shard {self.shard[0]} has connected to gateway!") + self.logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}") return self.state.client.dispatch(events.WebsocketReady(data)) case "RESUMED": - logger.info(f"Successfully resumed connection! Session_ID: {self.session_id}") + self.logger.info(f"Successfully resumed connection! Session_ID: {self.session_id}") self.state.client.dispatch(events.Resume()) return @@ -233,9 +232,9 @@ async def dispatch_event(self, data, seq, event) -> None: try: asyncio.create_task(processor(events.RawGatewayEvent(data.copy(), override_name=event_name))) except Exception as ex: - logger.error(f"Failed to run event processor for {event_name}: {ex}") + self.logger.error(f"Failed to run event processor for {event_name}: {ex}") else: - logger.debug(f"No processor for `{event_name}`") + self.logger.debug(f"No processor for `{event_name}`") self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name="raw_gateway_event")) self.state.client.dispatch(events.RawGatewayEvent(data.copy(), override_name=f"raw_{event.lower()}")) @@ -264,7 +263,7 @@ async def _identify(self) -> None: serialized = OverriddenJson.dumps(payload) await self.ws.send_str(serialized) - logger.debug( + self.logger.debug( f"Shard ID {self.shard[0]} has identified itself to Gateway, requesting intents: {self.state.intents}!" ) @@ -286,11 +285,11 @@ async def _resume_connection(self) -> None: serialized = OverriddenJson.dumps(payload) await self.ws.send_str(serialized) - logger.debug(f"{self.shard[0]} is attempting to resume a connection") + self.logger.debug(f"{self.shard[0]} is attempting to resume a connection") async def send_heartbeat(self) -> None: await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True) - logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat") + self.logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat") async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None: """Update the bot's presence status.""" diff --git a/naff/api/gateway/state.py b/naff/api/gateway/state.py index 7ccf32e99..a6c2a322e 100644 --- a/naff/api/gateway/state.py +++ b/naff/api/gateway/state.py @@ -1,13 +1,15 @@ import asyncio import traceback from datetime import datetime +from logging import Logger from typing import TYPE_CHECKING, Optional, Union +import attrs + import naff from naff.api import events -from naff.client.const import logger, MISSING, Absent +from naff.client.const import Absent, MISSING, get_logger from naff.client.errors import NaffException, WebSocketClosed -from naff.client.utils.attr_utils import define, field from naff.models.discord.activity import Activity from naff.models.discord.enums import Intents, Status, ActivityType from .gateway import GatewayClient @@ -18,7 +20,7 @@ __all__ = ("ConnectionState",) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ConnectionState: client: "Client" """The bot's client""" @@ -26,7 +28,7 @@ class ConnectionState: """The event intents in use""" shard_id: int """The shard ID of this state""" - _shard_ready: asyncio.Event = field(default=None) + _shard_ready: asyncio.Event = attrs.field(repr=False, default=None) """Indicates that this state is now ready""" gateway: Absent[GatewayClient] = MISSING @@ -43,6 +45,8 @@ class ConnectionState: _shard_task: asyncio.Task | None = None + logger: Logger = attrs.field(repr=False, init=False, factory=get_logger) + def __attrs_post_init__(self, *args, **kwargs) -> None: self._shard_ready = asyncio.Event() @@ -68,7 +72,7 @@ async def start(self) -> None: """Connect to the Discord Gateway.""" self.gateway_url = await self.client.http.get_gateway() - logger.debug(f"Starting Shard ID {self.shard_id}") + self.logger.debug(f"Starting Shard ID {self.shard_id}") self.start_time = datetime.now() self._shard_task = asyncio.create_task(self._ws_connect()) @@ -80,7 +84,7 @@ async def start(self) -> None: async def stop(self) -> None: """Disconnect from the Discord Gateway.""" - logger.debug(f"Shutting down shard ID {self.shard_id}") + self.logger.debug(f"Shutting down shard ID {self.shard_id}") if self.gateway is not None: self.gateway.close() self.gateway = None @@ -98,7 +102,7 @@ def clear_ready(self) -> None: async def _ws_connect(self) -> None: """Connect to the Discord Gateway.""" - logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...") + self.logger.info(f"Shard {self.shard_id} is attempting to connect to gateway...") try: async with GatewayClient(self, (self.shard_id, self.client.total_shards)) as self.gateway: try: @@ -123,7 +127,7 @@ async def _ws_connect(self) -> None: except Exception as e: self.client.dispatch(events.Disconnect()) - logger.error("".join(traceback.format_exception(type(e), e, e.__traceback__))) + self.logger.error("".join(traceback.format_exception(type(e), e, e.__traceback__))) async def change_presence( self, status: Optional[Union[str, Status]] = Status.ONLINE, activity: Absent[Union[Activity, str]] = MISSING @@ -149,7 +153,7 @@ async def change_presence( if activity.type == ActivityType.STREAMING: if not activity.url: - logger.warning("Streaming activity cannot be set without a valid URL attribute") + self.logger.warning("Streaming activity cannot be set without a valid URL attribute") elif activity.type not in [ ActivityType.GAME, ActivityType.STREAMING, @@ -157,7 +161,9 @@ async def change_presence( ActivityType.WATCHING, ActivityType.COMPETING, ]: - logger.warning(f"Activity type `{ActivityType(activity.type).name}` may not be enabled for bots") + self.logger.warning( + f"Activity type `{ActivityType(activity.type).name}` may not be enabled for bots" + ) else: activity = self.client.activity @@ -172,7 +178,7 @@ async def change_presence( if self.client.status: status = self.client.status else: - logger.warning("Status must be set to a valid status type, defaulting to online") + self.logger.warning("Status must be set to a valid status type, defaulting to online") status = Status.ONLINE self.client._status = status diff --git a/naff/api/gateway/websocket.py b/naff/api/gateway/websocket.py index 7c97362b7..c85e59e25 100644 --- a/naff/api/gateway/websocket.py +++ b/naff/api/gateway/websocket.py @@ -5,14 +5,13 @@ import zlib from abc import abstractmethod from types import TracebackType -from aiohttp import WSMsgType from typing import TypeVar, TYPE_CHECKING -from naff.client.const import logger -from naff.client.errors import WebSocketClosed -from naff.models.naff.cooldowns import CooldownSystem +from aiohttp import WSMsgType +from naff.client.errors import WebSocketClosed from naff.client.utils.input_utils import OverriddenJson +from naff.models.naff.cooldowns import CooldownSystem if TYPE_CHECKING: from naff.api.gateway.state import ConnectionState @@ -41,6 +40,7 @@ async def rate_limit(self) -> None: class WebsocketClient: def __init__(self, state: "ConnectionState") -> None: self.state = state + self.logger = state.client.logger self.ws = None self.ws_url = None @@ -134,11 +134,11 @@ async def send(self, data: str, bypass=False) -> None: bypass: Should the rate limit be ignored for this send (used for heartbeats) """ - logger.debug(f"Sending data to websocket: {data}") + self.logger.debug(f"Sending data to websocket: {data}") async with self._race_lock: if self.ws is None: - return logger.warning("Attempted to send data while websocket is not connected!") + return self.logger.warning("Attempted to send data while websocket is not connected!") if not bypass: await self.rl_manager.rate_limit() @@ -177,7 +177,7 @@ async def receive(self, force: bool = False) -> str: resp = await self.ws.receive() if resp.type == WSMsgType.CLOSE: - logger.debug(f"Disconnecting from gateway! Reason: {resp.data}::{resp.extra}") + self.logger.debug(f"Disconnecting from gateway! Reason: {resp.data}::{resp.extra}") if resp.data >= 4000: # This should propagate to __aexit__() which will forcefully shut down everything # and cleanup correctly. @@ -232,7 +232,7 @@ async def receive(self, force: bool = False) -> str: try: msg = OverriddenJson.loads(msg) except Exception as e: - logger.error(e) + self.logger.error(e) continue return msg @@ -270,7 +270,7 @@ async def run_bee_gees(self) -> None: await self._start_bee_gees() except Exception: self.close() - logger.error("The heartbeater raised an exception!", exc_info=True) + self.logger.error("The heartbeater raised an exception!", exc_info=True) async def _start_bee_gees(self) -> None: if self.heartbeat_interval is None: @@ -283,10 +283,10 @@ async def _start_bee_gees(self) -> None: else: return - logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds") + self.logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds") while not self._kill_bee_gees.is_set(): if not self._acknowledged.is_set(): - logger.warning( + self.logger.warning( f"Heartbeat has not been acknowledged for {self.heartbeat_interval} seconds," " likely zombied connection. Reconnect!" ) diff --git a/naff/api/http/http_client.py b/naff/api/http/http_client.py index c2b8847e0..44aab2f96 100644 --- a/naff/api/http/http_client.py +++ b/naff/api/http/http_client.py @@ -1,6 +1,7 @@ """This file handles the interaction with discords http endpoints.""" import asyncio import time +from logging import Logger from typing import Any, cast from urllib.parse import quote as _uriquote from weakref import WeakValueDictionary @@ -10,6 +11,7 @@ from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData from multidict import CIMultiDictProxy +import naff.client.const as constants from naff import models from naff.api.http.http_requests import ( BotRequests, @@ -27,13 +29,14 @@ ScheduledEventsRequests, ) from naff.client.const import ( + MISSING, __py_version__, __repo_url__, __version__, - logger, __api_version__, ) from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError +from naff.client.mixins.serialization import DictSerializationMixin from naff.client.utils.input_utils import response_decode, OverriddenJson from naff.client.utils.serializer import dict_filter from naff.models.discord.file import UPLOADABLE_TYPE @@ -46,7 +49,7 @@ class GlobalLock: def __init__(self) -> None: self._lock = asyncio.Lock() self.max_requests = 45 - self._calls = 0 + self._calls = self.max_requests self._reset_time = 0 @property @@ -78,7 +81,7 @@ async def wait(self) -> None: elif self._calls <= 0: await asyncio.sleep(self._reset_time - time.perf_counter()) self.reset_calls() - self._calls -= 1 + self._calls -= 1 class BucketLock: @@ -156,7 +159,7 @@ class HTTPClient( ): """A http client for sending requests to the Discord API.""" - def __init__(self, connector: BaseConnector | None = None) -> None: + def __init__(self, connector: BaseConnector | None = None, logger: Logger = MISSING) -> None: self.connector: BaseConnector | None = connector self.__session: ClientSession | None = None self.token: str | None = None @@ -170,6 +173,10 @@ def __init__(self, connector: BaseConnector | None = None) -> None: f"DiscordBot ({__repo_url__} {__version__} Python/{__py_version__}) aiohttp/{aiohttp.__version__}" ) + if logger is MISSING: + logger = constants.get_logger() + self.logger = logger + def get_ratelimit(self, route: Route) -> BucketLock: """ Get a route's rate limit bucket. @@ -203,7 +210,7 @@ def ingest_ratelimit(self, route: Route, header: CIMultiDictProxy, bucket_lock: if bucket_lock.bucket_hash: # We only ever try and cache the bucket if the bucket hash has been set (ignores unlimited endpoints) - logger.debug(f"Caching ingested rate limit data for: {bucket_lock.bucket_hash}") + self.logger.debug(f"Caching ingested rate limit data for: {bucket_lock.bucket_hash}") self._endpoints[route.rl_bucket] = bucket_lock.bucket_hash self.ratelimit_locks[bucket_lock.bucket_hash] = bucket_lock @@ -226,6 +233,13 @@ def _process_payload( if isinstance(payload, dict): payload = dict_filter(payload) + + for k, v in payload.items(): + if isinstance(v, DictSerializationMixin): + payload[k] = v.to_dict() + if isinstance(v, (list, tuple, set)): + payload[k] = [i.to_dict() if isinstance(i, DictSerializationMixin) else i for i in v] + else: payload = [dict_filter(x) if isinstance(x, dict) else x for x in payload] @@ -305,14 +319,14 @@ async def request( if result.get("global", False): # global ratelimit is reached # if we get a global, that's pretty bad, this would usually happen if the user is hitting the api from 2 clients sharing a token - logger.error( + self.logger.error( f"Bot has exceeded global ratelimit, locking REST API for {result['retry_after']} seconds" ) self.global_lock.set_reset_time(float(result["retry_after"])) continue elif result.get("message") == "The resource is being rate limited.": # resource ratelimit is reached - logger.warning( + self.logger.warning( f"{route.endpoint} The resource is being rate limited! " f"Reset in {result.get('retry_after')} seconds" ) @@ -323,21 +337,21 @@ async def request( # endpoint ratelimit is reached # 429's are unfortunately unavoidable, but we can attempt to avoid them # so long as these are infrequent we're doing well - logger.warning( + self.logger.warning( f"{route.endpoint} Has exceeded it's ratelimit ({lock.limit})! Reset in {lock.delta} seconds" ) await lock.defer_unlock() # lock this route and wait for unlock continue elif lock.remaining == 0: # Last call available in the bucket, lock until reset - logger.debug( + self.logger.debug( f"{route.endpoint} Has exhausted its ratelimit ({lock.limit})! Locking route for {lock.delta} seconds" ) await lock.blind_defer_unlock() # lock this route, but continue processing the current response elif response.status in {500, 502, 504}: # Server issues, retry - logger.warning( + self.logger.warning( f"{route.endpoint} Received {response.status}... retrying in {1 + attempt * 2} seconds" ) await asyncio.sleep(1 + attempt * 2) @@ -346,7 +360,7 @@ async def request( if not 300 > response.status >= 200: await self._raise_exception(response, route, result) - logger.debug( + self.logger.debug( f"{route.endpoint} Received {response.status} :: [{lock.remaining}/{lock.limit} calls remaining]" ) return result @@ -357,7 +371,7 @@ async def request( raise async def _raise_exception(self, response, route, result) -> None: - logger.error(f"{route.method}::{route.url}: {response.status}") + self.logger.error(f"{route.method}::{route.url}: {response.status}") if response.status == 403: raise Forbidden(response, response_data=result, route=route) @@ -369,7 +383,7 @@ async def _raise_exception(self, response, route, result) -> None: raise HTTPException(response, response_data=result, route=route) async def request_cdn(self, url, asset) -> bytes: # pyright: ignore [reportGeneralTypeIssues] - logger.debug(f"{asset} requests {url} from CDN") + self.logger.debug(f"{asset} requests {url} from CDN") async with self.__session.get(url) as response: if response.status == 200: return await response.read() diff --git a/naff/api/http/http_requests/bot.py b/naff/api/http/http_requests/bot.py index fede880f2..49f2393c4 100644 --- a/naff/api/http/http_requests/bot.py +++ b/naff/api/http/http_requests/bot.py @@ -3,7 +3,6 @@ import discord_typings from naff.models.naff.protocols import CanRequest - from ..route import Route __all__ = ("BotRequests",) diff --git a/naff/api/http/http_requests/channels.py b/naff/api/http/http_requests/channels.py index 81be399ee..ae70c518e 100644 --- a/naff/api/http/http_requests/channels.py +++ b/naff/api/http/http_requests/channels.py @@ -567,7 +567,7 @@ async def create_tag( payload: PAYLOAD_TYPE = { "name": name, "emoji_id": int(emoji_id) if emoji_id else None, - "emoji_name": emoji_name, + "emoji_name": emoji_name if emoji_name else None, } payload = dict_filter_none(payload) diff --git a/naff/api/http/http_requests/guild.py b/naff/api/http/http_requests/guild.py index 9a8e43cbf..f5de63862 100644 --- a/naff/api/http/http_requests/guild.py +++ b/naff/api/http/http_requests/guild.py @@ -4,7 +4,6 @@ from naff.client.utils.serializer import dict_filter_none from naff.models.naff.protocols import CanRequest - from ..route import Route, PAYLOAD_TYPE __all__ = ("GuildRequests",) @@ -133,7 +132,7 @@ async def modify_guild( ) payload = kwargs.copy() for key, value in kwargs.items(): - if key not in expected or value is None: # todo review + if key not in expected or value is None: del payload[key] # only do the request if there is something to modify diff --git a/naff/api/http/http_requests/interactions.py b/naff/api/http/http_requests/interactions.py index ba80fed8c..7df7b4aa0 100644 --- a/naff/api/http/http_requests/interactions.py +++ b/naff/api/http/http_requests/interactions.py @@ -2,8 +2,8 @@ import discord_typings -from naff.models.naff.protocols import CanRequest from naff.client.const import GLOBAL_SCOPE +from naff.models.naff.protocols import CanRequest from ..route import Route __all__ = ("InteractionRequests",) @@ -62,8 +62,8 @@ async def get_application_commands( ) async def overwrite_application_commands( - self, app_id: "Snowflake_Type", data: list[dict], guild_id: "Snowflake_Type" # todo type "data" - ) -> list[discord_typings.ApplicationCommandData]: + self, app_id: "Snowflake_Type", data: list[dict], guild_id: "Snowflake_Type" + ) -> list[discord_typings.ApplicationCommandData]: # todo type "data" """ Take a list of commands and overwrite the existing command list within the given scope @@ -142,7 +142,7 @@ async def edit_interaction_message( payload: dict, application_id: "Snowflake_Type", token: str, - message_id: str = "@original", + message_id: "str|Snowflake_Type" = "@original", files: list["UPLOADABLE_TYPE"] | None = None, ) -> discord_typings.MessageData: """ @@ -166,6 +166,20 @@ async def edit_interaction_message( ) return cast(discord_typings.MessageData, result) + async def delete_interaction_message( + self, application_id: "Snowflake_Type", token: str, message_id: "str | Snowflake_Type" = "@original" + ) -> None: + """ + Deletes an existing interaction message. + + Args: + application_id: The id of the application. + token: The token of the interaction. + message_id: The target message to delete. Defaults to @original which represents the initial response message. + + """ + return await self.request(Route("DELETE", f"/webhooks/{int(application_id)}/{token}/messages/{message_id}")) + async def get_interaction_message( self, application_id: "Snowflake_Type", token: str, message_id: str = "@original" ) -> discord_typings.MessageData: diff --git a/naff/api/http/http_requests/members.py b/naff/api/http/http_requests/members.py index ce64dd21b..e18a4e4cf 100644 --- a/naff/api/http/http_requests/members.py +++ b/naff/api/http/http_requests/members.py @@ -1,13 +1,13 @@ -from typing import TYPE_CHECKING, cast from datetime import datetime +from typing import TYPE_CHECKING, cast import discord_typings from naff.client.const import Missing, MISSING +from naff.client.utils.serializer import dict_filter_none +from naff.models.discord.timestamp import Timestamp from naff.models.naff.protocols import CanRequest from ..route import Route, PAYLOAD_TYPE -from naff.models.discord.timestamp import Timestamp -from naff.client.utils.serializer import dict_filter_none __all__ = ("MemberRequests",) diff --git a/naff/api/http/http_requests/messages.py b/naff/api/http/http_requests/messages.py index 763eab6b2..25a8a80d8 100644 --- a/naff/api/http/http_requests/messages.py +++ b/naff/api/http/http_requests/messages.py @@ -3,7 +3,6 @@ import discord_typings from naff.models.naff.protocols import CanRequest - from ..route import Route __all__ = ("MessageRequests",) diff --git a/naff/api/http/http_requests/threads.py b/naff/api/http/http_requests/threads.py index 6b317542c..af1ee1413 100644 --- a/naff/api/http/http_requests/threads.py +++ b/naff/api/http/http_requests/threads.py @@ -3,10 +3,10 @@ import discord_typings from aiohttp import FormData +from naff.api.http.route import Route from naff.client.const import MISSING, Absent -from naff.client.utils.attr_converters import timestamp_converter from naff.models.discord.enums import ChannelTypes -from naff.api.http.route import Route +from naff.models.discord.timestamp import Timestamp __all__ = ("ThreadRequests",) @@ -93,7 +93,7 @@ async def list_public_archived_threads( if limit: payload["limit"] = limit if before: - payload["before"] = timestamp_converter(before) + payload["before"] = Timestamp.from_snowflake(before).isoformat() return await self.request(Route("GET", f"/channels/{channel_id}/threads/archived/public"), params=payload) async def list_private_archived_threads( @@ -115,7 +115,7 @@ async def list_private_archived_threads( if limit: payload["limit"] = limit if before: - payload["before"] = before + payload["before"] = Timestamp.from_snowflake(before).isoformat() return await self.request(Route("GET", f"/channels/{channel_id}/threads/archived/private"), params=payload) async def list_joined_private_archived_threads( @@ -210,23 +210,23 @@ async def create_forum_thread( name: The name of the thread auto_archive_duration: Time before the thread will be automatically archived. Note 3 day and 7 day archive durations require the server to be boosted. message: The message-content for the post/thread + applied_tags: The tags to apply to the thread rate_limit_per_user: The time users must wait between sending messages + files: The files to upload reason: The reason for creating this thread Returns: The created thread object """ - # note: `{"use_nested_fields": 1}` seems to be a temporary flag until forums launch return await self.request( Route("POST", f"/channels/{channel_id}/threads"), payload={ "name": name, "auto_archive_duration": auto_archive_duration, "rate_limit_per_user": rate_limit_per_user, - "applied_tags": applied_tags, "message": message, + "applied_tags": applied_tags, }, - params={"use_nested_fields": 1}, files=files, reason=reason, ) diff --git a/naff/api/http/http_requests/webhooks.py b/naff/api/http/http_requests/webhooks.py index e0ea3e9a5..30f95a691 100644 --- a/naff/api/http/http_requests/webhooks.py +++ b/naff/api/http/http_requests/webhooks.py @@ -2,8 +2,8 @@ import discord_typings -from ..route import Route from naff.client.utils.serializer import dict_filter_none +from ..route import Route __all__ = ("WebhookRequests",) diff --git a/naff/api/http/route.py b/naff/api/http/route.py index 58c31691a..b2614dd4e 100644 --- a/naff/api/http/route.py +++ b/naff/api/http/route.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Optional from urllib.parse import quote as _uriquote + from naff.client.const import __api_version__ if TYPE_CHECKING: diff --git a/naff/api/voice/audio.py b/naff/api/voice/audio.py index b2319e8a9..afe61579c 100644 --- a/naff/api/voice/audio.py +++ b/naff/api/voice/audio.py @@ -66,6 +66,7 @@ class BaseAudio(ABC): def __del__(self) -> None: self.cleanup() + @abstractmethod def cleanup(self) -> None: """A method to optionally cleanup after this object is no longer required.""" ... diff --git a/naff/api/voice/encryption.py b/naff/api/voice/encryption.py index 3d7ad1b01..661bcf36d 100644 --- a/naff/api/voice/encryption.py +++ b/naff/api/voice/encryption.py @@ -32,7 +32,7 @@ def encrypt(self, mode: str, header: bytes, data) -> bytes: raise RuntimeError(f"Unsupported encryption type requested: {mode}") def xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes: - # todo: hi! + # todo: redundant but might be useful for some weird edge case ... def xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes: diff --git a/naff/api/voice/opus.py b/naff/api/voice/opus.py index 8279f01b9..f58daa2dc 100644 --- a/naff/api/voice/opus.py +++ b/naff/api/voice/opus.py @@ -1,8 +1,8 @@ -import os -import sys import array import ctypes import ctypes.util +import os +import sys from enum import IntEnum from typing import Any @@ -216,7 +216,7 @@ def set_expected_pack_loss(self, expected_packet_loss: float) -> None: self.lib_opus.opus_encoder_ctl(self.encoder, EncoderCTL.CTL_SET_PLP, self.expected_packet_loss) def encode(self, pcm: bytes) -> bytes: - """todo: doc""" + """Encode a frame of audio""" max_data_bytes = len(pcm) pcm = ctypes.cast(pcm, c_int16_ptr) data = (ctypes.c_char * max_data_bytes)() diff --git a/naff/api/voice/player.py b/naff/api/voice/player.py index 17a454a9b..7582cce5e 100644 --- a/naff/api/voice/player.py +++ b/naff/api/voice/player.py @@ -2,12 +2,12 @@ import shutil import threading from asyncio import AbstractEventLoop, run_coroutine_threadsafe +from logging import Logger from time import sleep, perf_counter from typing import Optional, TYPE_CHECKING from naff.api.voice.audio import BaseAudio, AudioVolume from naff.api.voice.opus import Encoder -from naff.client.const import logger if TYPE_CHECKING: from naff.models.naff.active_voice_state import ActiveVoiceState @@ -22,6 +22,7 @@ def __init__(self, audio, v_state, loop) -> None: self.current_audio: Optional[BaseAudio] = audio self.state: "ActiveVoiceState" = v_state self.loop: AbstractEventLoop = loop + self.logger: Logger = self.state.ws.logger self._encoder: Encoder = Encoder() @@ -99,14 +100,14 @@ def run(self) -> None: self._stopped.clear() asyncio.run_coroutine_threadsafe(self.state.ws.speaking(True), self.loop) - logger.debug(f"Now playing {self.current_audio!r}") + self.logger.debug(f"Now playing {self.current_audio!r}") start = None try: while not self._stop_event.is_set(): if not self.state.ws.ready.is_set() or not self._resume.is_set(): run_coroutine_threadsafe(self.state.ws.speaking(False), self.loop) - logger.debug("Voice playback has been suspended!") + self.logger.debug("Voice playback has been suspended!") wait_for = [] @@ -122,7 +123,7 @@ def run(self) -> None: continue run_coroutine_threadsafe(self.state.ws.speaking(), self.loop) - logger.debug("Voice playback has been resumed!") + self.logger.debug("Voice playback has been resumed!") start = None loops = 0 diff --git a/naff/api/voice/voice_gateway.py b/naff/api/voice/voice_gateway.py index 87d23d506..000230609 100644 --- a/naff/api/voice/voice_gateway.py +++ b/naff/api/voice/voice_gateway.py @@ -10,7 +10,6 @@ from naff.api.gateway.websocket import WebsocketClient from naff.api.voice.encryption import Encryption -from naff.client.const import logger from naff.client.errors import VoiceWebSocketClosed from naff.client.utils.input_utils import OverriddenJson @@ -106,7 +105,7 @@ async def receive(self, force=False) -> str: resp = await self.ws.receive() if resp.type == WSMsgType.CLOSE: - logger.debug(f"Disconnecting from voice gateway! Reason: {resp.data}::{resp.extra}") + self.logger.debug(f"Disconnecting from voice gateway! Reason: {resp.data}::{resp.extra}") if resp.data in (4006, 4009, 4014, 4015): # these are all recoverable close codes, anything else means we're foobared # codes: session expired, session timeout, disconnected, server crash @@ -159,7 +158,7 @@ async def receive(self, force=False) -> str: try: msg = OverriddenJson.loads(msg) except Exception as e: - logger.error(e) + self.logger.error(e) return msg @@ -169,26 +168,28 @@ async def dispatch_opcode(self, data, op) -> None: self.latency.append(time.perf_counter() - self._last_heartbeat) if self._last_heartbeat != 0 and self.latency[-1] >= 15: - logger.warning(f"High Latency! Voice heartbeat took {self.latency[-1]:.1f}s to be acknowledged!") + self.logger.warning( + f"High Latency! Voice heartbeat took {self.latency[-1]:.1f}s to be acknowledged!" + ) else: - logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds") + self.logger.debug(f"❤ Heartbeat acknowledged after {self.latency[-1]:.5f} seconds") return self._acknowledged.set() case OP.READY: - logger.debug("Discord send VC Ready! Establishing a socket connection...") + self.logger.debug("Discord send VC Ready! Establishing a socket connection...") self.voice_ip = data["ip"] self.voice_port = data["port"] self.ssrc = data["ssrc"] self.voice_modes = [mode for mode in data["modes"] if mode in Encryption.SUPPORTED] if len(self.voice_modes) == 0: - logger.critical("NO VOICE ENCRYPTION MODES SHARED WITH GATEWAY!") + self.logger.critical("NO VOICE ENCRYPTION MODES SHARED WITH GATEWAY!") await self.establish_voice_socket() case OP.SESSION_DESCRIPTION: - logger.info(f"Voice connection established; using {data['mode']}") + self.logger.info(f"Voice connection established; using {data['mode']}") self.encryptor = Encryption(data["secret_key"]) self.ready.set() if self.cond: @@ -196,7 +197,7 @@ async def dispatch_opcode(self, data, op) -> None: self.cond.notify() case _: - return logger.debug(f"Unhandled OPCODE: {op} = {data = }") + return self.logger.debug(f"Unhandled OPCODE: {op} = {data = }") async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None: async with self._race_lock: @@ -208,13 +209,13 @@ async def reconnect(self, *, resume: bool = False, code: int = 1012) -> None: self.ws = None if not resume: - logger.debug("Waiting for updated server information...") + self.logger.debug("Waiting for updated server information...") try: await asyncio.wait_for(self._voice_server_update.wait(), timeout=5) except asyncio.TimeoutError: self._kill_bee_gees.set() self.close() - logger.debug("Terminating VoiceGateway due to disconnection") + self.logger.debug("Terminating VoiceGateway due to disconnection") return self._voice_server_update.clear() @@ -248,7 +249,7 @@ async def _resume_connection(self) -> None: async def establish_voice_socket(self) -> None: """Establish the socket connection to discord""" - logger.debug("IP Discovery in progress...") + self.logger.debug("IP Discovery in progress...") self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket.setblocking(False) @@ -260,14 +261,14 @@ async def establish_voice_socket(self) -> None: self.socket.sendto(packet, (self.voice_ip, self.voice_port)) resp = await self.loop.sock_recv(self.socket, 70) - logger.debug(f"Voice Initial Response Received: {resp}") + self.logger.debug(f"Voice Initial Response Received: {resp}") ip_start = 4 ip_end = resp.index(0, ip_start) self.me_ip = resp[ip_start:ip_end].decode("ascii") self.me_port = struct.unpack_from(">H", resp, len(resp) - 2)[0] - logger.debug(f"IP Discovered: {self.me_ip} #{self.me_port}") + self.logger.debug(f"IP Discovered: {self.me_ip} #{self.me_port}") await self._select_protocol() @@ -301,7 +302,7 @@ def send_packet(self, data: bytes, encoder, needs_encode=True) -> None: async def send_heartbeat(self) -> None: await self.send_json({"op": OP.HEARTBEAT, "d": random.uniform(0.0, 1.0)}) - logger.debug("❤ Voice Connection is sending Heartbeat") + self.logger.debug("❤ Voice Connection is sending Heartbeat") async def _identify(self) -> None: """Send an identify payload to the voice gateway.""" @@ -317,7 +318,7 @@ async def _identify(self) -> None: serialized = OverriddenJson.dumps(payload) await self.ws.send_str(serialized) - logger.debug("Voice Connection has identified itself to Voice Gateway") + self.logger.debug("Voice Connection has identified itself to Voice Gateway") async def _select_protocol(self) -> None: """Inform Discord of our chosen protocol.""" diff --git a/naff/client/auto_shard_client.py b/naff/client/auto_shard_client.py index 2f282a87d..4b3c14be7 100644 --- a/naff/client/auto_shard_client.py +++ b/naff/client/auto_shard_client.py @@ -4,16 +4,16 @@ from typing import TYPE_CHECKING, Optional import naff.api.events as events +from naff.api.events import ShardConnect from naff.api.gateway.state import ConnectionState from naff.client.client import Client -from naff.client.const import logger, MISSING +from naff.client.const import MISSING from naff.models import ( Guild, to_snowflake, ) -from naff.models.naff.listener import Listener from naff.models.discord import Status, Activity -from naff.api.events import ShardConnect +from naff.models.naff.listener import Listener if TYPE_CHECKING: from naff.models import Snowflake_Type @@ -75,7 +75,7 @@ def latencies(self) -> dict[int, float]: async def stop(self) -> None: """Shutdown the bot.""" - logger.debug("Stopping the bot.") + self.logger.debug("Stopping the bot.") self._ready.clear() await self.http.close() await asyncio.gather(*(state.stop() for state in self._connection_states)) @@ -135,7 +135,7 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: try: await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) except asyncio.TimeoutError: - logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") + self.logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") break self._guild_event.clear() if all(self.cache.get_guild(g_id) is not None for g_id in expected_guilds): @@ -143,16 +143,16 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: break if self.fetch_members: - logger.info(f"Shard {shard_id} is waiting for members to be chunked") + self.logger.info(f"Shard {shard_id} is waiting for members to be chunked") await asyncio.gather(*(guild.chunked.wait() for guild in self.guilds if guild.id in expected_guilds)) else: - logger.warning( + self.logger.warning( f"Shard {shard_id} reports it has 0 guilds, this is an indicator you may be using too many shards" ) # noinspection PyProtectedMember connection_state._shard_ready.set() self.dispatch(ShardConnect(shard_id)) - logger.debug(f"Shard {shard_id} is now ready") + self.logger.debug(f"Shard {shard_id} is now ready") # noinspection PyProtectedMember await asyncio.gather(*[shard._shard_ready.wait() for shard in self._connection_states]) @@ -162,7 +162,7 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: try: await asyncio.gather(*self.async_startup_tasks) except Exception as e: - self.dispatch(events.Error("async-extension-loader", e)) + self.dispatch(events.Error(source="async-extension-loader", error=e)) # cache slash commands if not self._startup: @@ -182,7 +182,7 @@ async def astart(self, token: str) -> None: Args: token: Your bot's token """ - logger.debug("Starting http client...") + self.logger.debug("Starting http client...") await self.login(token) tasks = [] @@ -195,7 +195,7 @@ async def astart(self, token: str) -> None: for bucket in shard_buckets.values(): for shard in bucket: - logger.debug(f"Starting {shard.shard_id}") + self.logger.debug(f"Starting {shard.shard_id}") start = time.perf_counter() tasks.append(asyncio.create_task(shard.start())) @@ -233,11 +233,11 @@ async def login(self, token) -> None: self.total_shards = data["shards"] elif data["shards"] != self.total_shards: recommended_shards = data["shards"] - logger.info( + self.logger.info( f"Discord recommends you start with {recommended_shards} shard{'s' if recommended_shards != 1 else ''} instead of {self.total_shards}" ) - logger.debug(f"Starting bot with {self.total_shards} shard{'s' if self.total_shards != 1 else ''}") + self.logger.debug(f"Starting bot with {self.total_shards} shard{'s' if self.total_shards != 1 else ''}") self._connection_states: list[ConnectionState] = [ ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards) ] diff --git a/naff/client/client.py b/naff/client/client.py index 04a320e8a..5bdaad6cb 100644 --- a/naff/client/client.py +++ b/naff/client/client.py @@ -33,13 +33,12 @@ import naff.api.events as events import naff.client.const as constants -from naff.models.naff.context import SendableContext from naff.api.events import MessageCreate, RawGatewayEvent, processors, Component, BaseEvent from naff.api.gateway.gateway import GatewayClient from naff.api.gateway.state import ConnectionState from naff.api.http.http_client import HTTPClient from naff.client import errors -from naff.client.const import GLOBAL_SCOPE, MISSING, MENTION_PREFIX, Absent, EMBED_MAX_DESC_LENGTH, logger +from naff.client.const import GLOBAL_SCOPE, MISSING, MENTION_PREFIX, Absent, EMBED_MAX_DESC_LENGTH, get_logger from naff.client.errors import ( BotException, ExtensionLoadException, @@ -50,6 +49,7 @@ NotFound, ) from naff.client.smart_cache import GlobalCache +from naff.client.utils import NullCache from naff.client.utils.input_utils import get_first_word, get_args from naff.client.utils.misc_utils import get_event_name, wrap_partial from naff.client.utils.serializer import to_image_data @@ -82,7 +82,6 @@ AutocompleteContext, HybridContext, ComponentCommand, - Context, application_commands_to_dict, sync_needed, VoiceRegion, @@ -96,7 +95,7 @@ from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.discord.modal import Modal from naff.models.naff.active_voice_state import ActiveVoiceState -from naff.models.naff.application_commands import ModalCommand +from naff.models.naff.application_commands import ContextMenu, ModalCommand from naff.models.naff.auto_defer import AutoDefer from naff.models.naff.hybrid_commands import _prefixed_from_slash, _base_subcommand_generator from naff.models.naff.listener import Listener @@ -185,6 +184,7 @@ class Client( processors.AutoModEvents, processors.ChannelEvents, processors.GuildEvents, + processors.IntegrationEvents, processors.MemberEvents, processors.MessageEvents, processors.ReactionEvents, @@ -224,6 +224,7 @@ class Client( shard_id: The zero based int ID of this shard debug_scope: Force all application commands to be registered within this scope + disable_dm_commands: Should interaction commands be disabled in DMs? basic_logging: Utilise basic logging to output library data to console. Do not use in combination with `Client.logger` logging_level: The level of logging to use for basic_logging. Do not use in combination with `Client.logger` logger: The logger NAFF should use. Do not use in combination with `Client.basic_logging` and `Client.logging_level`. Note: Different loggers with multiple clients are not supported @@ -251,6 +252,7 @@ def __init__( debug_scope: Absent["Snowflake_Type"] = MISSING, default_prefix: str | Iterable[str] = MENTION_PREFIX, delete_unused_application_cmds: bool = False, + disable_dm_commands: bool = False, enforce_interaction_perms: bool = True, fetch_members: bool = False, generate_prefixes: Absent[Callable[..., Coroutine]] = MISSING, @@ -258,7 +260,7 @@ def __init__( global_pre_run_callback: Absent[Callable[..., Coroutine]] = MISSING, intents: Union[int, Intents] = Intents.DEFAULT, interaction_context: Type[InteractionContext] = InteractionContext, - logger: logging.Logger = logger, + logger: logging.Logger = MISSING, owner_ids: Iterable["Snowflake_Type"] = (), modal_context: Type[ModalContext] = ModalContext, prefixed_context: Type[PrefixedContext] = PrefixedContext, @@ -273,6 +275,9 @@ def __init__( logging_level: int = logging.INFO, **kwargs, ) -> None: + if logger is MISSING: + logger = constants.get_logger() + if basic_logging: logging.basicConfig() logger.setLevel(logging_level) @@ -282,7 +287,7 @@ def __init__( """The logger NAFF should use. Do not use in combination with `Client.basic_logging` and `Client.logging_level`. !!! note Different loggers with multiple clients are not supported""" - constants.logger = logger + constants._logger = logger # Configuration self.sync_interactions: bool = sync_interactions @@ -309,7 +314,7 @@ def __init__( # resources - self.http: HTTPClient = HTTPClient() + self.http: HTTPClient = HTTPClient(logger=self.logger) """The HTTP client to use when interacting with discord endpoints""" # context objects @@ -330,6 +335,7 @@ def __init__( self._ready = asyncio.Event() self._closed = False self._startup = False + self.disable_dm_commands = disable_dm_commands self._guild_event = asyncio.Event() self.guild_event_timeout = 3 @@ -363,6 +369,10 @@ def __init__( """A dictionary of registered prefixed commands: `{name: command}`""" self.interactions: Dict["Snowflake_Type", Dict[str, InteractionCommand]] = {} """A dictionary of registered application commands: `{cmd_id: command}`""" + self.interaction_tree: Dict[ + "Snowflake_Type", Dict[str, InteractionCommand | Dict[str, InteractionCommand]] + ] = {} + """A dictionary of registered application commands in a tree""" self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {} self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {} self._interaction_scopes: Dict["Snowflake_Type", "Snowflake_Type"] = {} @@ -370,7 +380,7 @@ def __init__( self.__modules = {} self.ext = {} """A dictionary of mounted ext""" - self.listeners: Dict[str, List] = {} + self.listeners: Dict[str, list[Listener]] = {} self.waits: Dict[str, List] = {} self.owner_ids: set[Snowflake_Type] = set(owner_ids) @@ -489,7 +499,7 @@ def get_guild_websocket(self, id: "Snowflake_Type") -> GatewayClient: def _sanity_check(self) -> None: """Checks for possible and common errors in the bot's configuration.""" - logger.debug("Running client sanity checks...") + self.logger.debug("Running client sanity checks...") contexts = { self.interaction_context: InteractionContext, self.prefixed_context: PrefixedContext, @@ -503,20 +513,30 @@ def _sanity_check(self) -> None: raise TypeError(f"{obj.__name__} must inherit from {expected.__name__}") if self.del_unused_app_cmd: - logger.warning( + self.logger.warning( "As `delete_unused_application_cmds` is enabled, the client must cache all guilds app-commands, this could take a while." ) if Intents.GUILDS not in self._connection_state.intents: - logger.warning("GUILD intent has not been enabled; this is very likely to cause errors") + self.logger.warning("GUILD intent has not been enabled; this is very likely to cause errors") if self.fetch_members and Intents.GUILD_MEMBERS not in self._connection_state.intents: raise BotException("Members Intent must be enabled in order to use fetch members") elif self.fetch_members: - logger.warning("fetch_members enabled; startup will be delayed") + self.logger.warning("fetch_members enabled; startup will be delayed") if len(self.processors) == 0: - logger.warning("No Processors are loaded! This means no events will be processed!") + self.logger.warning("No Processors are loaded! This means no events will be processed!") + + caches = [ + c[0] + for c in inspect.getmembers(self.cache, predicate=lambda x: isinstance(x, dict)) + if not c[0].startswith("__") + ] + for cache in caches: + _cache_obj = getattr(self.cache, cache) + if isinstance(_cache_obj, NullCache): + self.logger.warning(f"{cache} has been disabled") async def generate_prefixes(self, bot: "Client", message: Message) -> str | Iterable[str]: """ @@ -548,6 +568,10 @@ async def generate_prefixes(bot, message): def _queue_task(self, coro: Listener, event: BaseEvent, *args, **kwargs) -> asyncio.Task: async def _async_wrap(_coro: Listener, _event: BaseEvent, *_args, **_kwargs) -> None: try: + if not isinstance(_event, (events.Error, events.RawGatewayEvent)): + if coro.delay_until_ready and not self.is_ready: + await self.wait_until_ready() + if len(_event.__attrs_attrs__) == 2: # override_name & bot await _coro() @@ -560,7 +584,7 @@ async def _async_wrap(_coro: Listener, _event: BaseEvent, *_args, **_kwargs) -> # No infinite loops please self.default_error_handler(repr(event), e) else: - self.dispatch(events.Error(repr(event), e)) + self.dispatch(events.Error(source=repr(event), error=e)) wrapped = _async_wrap(coro, event, *args, **kwargs) @@ -581,71 +605,76 @@ def default_error_handler(source: str, error: BaseException) -> None: if isinstance(error, HTTPException): # HTTPException's are of 3 known formats, we can parse them for human readable errors try: - errors = error.search_for_message(error.errors) - out = f"HTTPException: {error.status}|{error.response.reason}: " + "\n".join(errors) + out = [str(error)] except Exception: # noqa : S110 pass - logger.error( + get_logger().error( "Ignoring exception in {}:{}{}".format(source, "\n" if len(out) > 1 else " ", "".join(out)), ) - @Listener.create() - async def _on_error(self, event: events.Error) -> None: - await self.on_error(event.source, event.error, *event.args, **event.kwargs) - - async def on_error(self, source: str, error: Exception, *args, **kwargs) -> None: + @Listener.create(is_default_listener=True) + async def on_error(self, event: events.Error) -> None: """ Catches all errors dispatched by the library. - By default it will format and print them to console + By default it will format and print them to console. - Override this to change error handling behaviour + Listen to the `Error` event to overwrite this behaviour. """ - self.default_error_handler(source, error) + self.default_error_handler(event.source, event.error) - async def on_command_error(self, ctx: SendableContext, error: Exception, *args, **kwargs) -> None: + @Listener.create(is_default_listener=True) + async def on_command_error(self, event: events.CommandError) -> None: """ Catches all errors dispatched by commands. - By default it will call `Client.on_error` + By default it will dispatch the `Error` event. - Override this to change error handling behavior + Listen to the `CommandError` event to overwrite this behaviour. """ - self.dispatch(events.Error(f"cmd /`{ctx.invoke_target}`", error, args, kwargs, ctx)) + self.dispatch( + events.Error( + source=f"cmd `/{event.ctx.invoke_target}`", + error=event.error, + args=event.args, + kwargs=event.kwargs, + ctx=event.ctx, + ) + ) try: - if isinstance(error, errors.CommandOnCooldown): - await ctx.send( + if isinstance(event.error, errors.CommandOnCooldown): + await event.ctx.send( embeds=Embed( description=f"This command is on cooldown!\n" - f"Please try again in {int(error.cooldown.get_cooldown_time())} seconds", + f"Please try again in {int(event.error.cooldown.get_cooldown_time())} seconds", color=BrandColors.FUCHSIA, ) ) - elif isinstance(error, errors.MaxConcurrencyReached): - await ctx.send( + elif isinstance(event.error, errors.MaxConcurrencyReached): + await event.ctx.send( embeds=Embed( description="This command has reached its maximum concurrent usage!\n" "Please try again shortly.", color=BrandColors.FUCHSIA, ) ) - elif isinstance(error, errors.CommandCheckFailure): - await ctx.send( + elif isinstance(event.error, errors.CommandCheckFailure): + await event.ctx.send( embeds=Embed( description="You do not have permission to run this command!", color=BrandColors.YELLOW, ) ) elif self.send_command_tracebacks: - out = "".join(traceback.format_exception(error)) + out = "".join(traceback.format_exception(event.error)) if self.http.token is not None: out = out.replace(self.http.token, "[REDACTED TOKEN]") - await ctx.send( + await event.ctx.send( embeds=Embed( - title=f"Error: {type(error).__name__}", + title=f"Error: {type(event.error).__name__}", color=BrandColors.RED, description=f"```\n{out[:EMBED_MAX_DESC_LENGTH-8]}```", ) @@ -653,85 +682,135 @@ async def on_command_error(self, ctx: SendableContext, error: Exception, *args, except errors.NaffException: pass - async def on_command(self, ctx: Context) -> None: + @Listener.create(is_default_listener=True) + async def on_command_completion(self, event: events.CommandCompletion) -> None: """ Called *after* any command is ran. - By default, it will simply log the command, override this to change that behaviour + By default, it will simply log the command. - Args: - ctx: The context of the command that was called + Listen to the `CommandCompletion` event to overwrite this behaviour. """ - if isinstance(ctx, PrefixedContext): + if isinstance(event.ctx, PrefixedContext): symbol = "@" - elif isinstance(ctx, InteractionContext): + elif isinstance(event.ctx, InteractionContext): symbol = "/" + elif isinstance(event.ctx, HybridContext): + symbol = "@/" else: symbol = "?" # likely custom context - logger.info(f"Command Called: {symbol}{ctx.invoke_target} with {ctx.args = } | {ctx.kwargs = }") + self.logger.info( + f"Command Called: {symbol}{event.ctx.invoke_target} with {event.ctx.args = } | {event.ctx.kwargs = }" + ) - async def on_component_error(self, ctx: ComponentContext, error: Exception, *args, **kwargs) -> None: + @Listener.create(is_default_listener=True) + async def on_component_error(self, event: events.ComponentError) -> None: """ Catches all errors dispatched by components. - By default it will call `Naff.on_error` + By default it will dispatch the `Error` event. - Override this to change error handling behavior + Listen to the `ComponentError` event to overwrite this behaviour. """ - return self.dispatch(events.Error(f"Component Callback for {ctx.custom_id}", error, args, kwargs, ctx)) + self.dispatch( + events.Error( + source=f"Component Callback for {event.ctx.custom_id}", + error=event.error, + args=event.args, + kwargs=event.kwargs, + ctx=event.ctx, + ) + ) - async def on_component(self, ctx: ComponentContext) -> None: + @Listener.create(is_default_listener=True) + async def on_component_completion(self, event: events.ComponentCompletion) -> None: """ Called *after* any component callback is ran. - By default, it will simply log the component use, override this to change that behaviour + By default, it will simply log the component use. - Args: - ctx: The context of the component that was called + Listen to the `ComponentCompletion` event to overwrite this behaviour. """ symbol = "Âĸ" - logger.info(f"Component Called: {symbol}{ctx.invoke_target} with {ctx.args = } | {ctx.kwargs = }") + self.logger.info( + f"Component Called: {symbol}{event.ctx.invoke_target} with {event.ctx.args = } | {event.ctx.kwargs = }" + ) - async def on_autocomplete_error(self, ctx: AutocompleteContext, error: Exception, *args, **kwargs) -> None: + @Listener.create(is_default_listener=True) + async def on_autocomplete_error(self, event: events.AutocompleteError) -> None: """ Catches all errors dispatched by autocompletion options. - By default it will call `Naff.on_error` + By default it will dispatch the `Error` event. - Override this to change error handling behavior + Listen to the `AutocompleteError` event to overwrite this behaviour. """ - return self.dispatch( + self.dispatch( events.Error( - f"Autocomplete Callback for /{ctx.invoke_target} - Option: {ctx.focussed_option}", - error, - args, - kwargs, - ctx, + source=f"Autocomplete Callback for /{event.ctx.invoke_target} - Option: {event.ctx.focussed_option}", + error=event.error, + args=event.args, + kwargs=event.kwargs, + ctx=event.ctx, ) ) - async def on_autocomplete(self, ctx: AutocompleteContext) -> None: + @Listener.create(is_default_listener=True) + async def on_autocomplete_completion(self, event: events.AutocompleteCompletion) -> None: """ Called *after* any autocomplete callback is ran. - By default, it will simply log the autocomplete callback, override this to change that behaviour + By default, it will simply log the autocomplete callback. - Args: - ctx: The context of the command that was called + Listen to the `AutocompleteCompletion` event to overwrite this behaviour. """ symbol = "$" - logger.info(f"Autocomplete Called: {symbol}{ctx.invoke_target} with {ctx.args = } | {ctx.kwargs = }") + self.logger.info( + f"Autocomplete Called: {symbol}{event.ctx.invoke_target} with {event.ctx.focussed_option = } | {event.ctx.kwargs = }" + ) + + @Listener.create(is_default_listener=True) + async def on_modal_error(self, event: events.ModalError) -> None: + """ + Catches all errors dispatched by modals. + + By default it will dispatch the `Error` event. + + Listen to the `ModalError` event to overwrite this behaviour. + + """ + self.dispatch( + events.Error( + source=f"Modal Callback for custom_id {event.ctx.custom_id}", + error=event.error, + args=event.args, + kwargs=event.kwargs, + ctx=event.ctx, + ) + ) + + @Listener.create(is_default_listener=True) + async def on_modal_completion(self, event: events.ModalCompletion) -> None: + """ + Called *after* any modal callback is ran. + + By default, it will simply log the modal callback. + + Listen to the `ModalCompletion` event to overwrite this behaviour. + + """ + self.logger.info(f"Modal Called: {event.ctx.custom_id = } with {event.ctx.responses = }") @Listener.create() async def on_resume(self) -> None: self._ready.set() - @Listener.create() + @Listener.create(is_default_listener=True) async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: """ Catches websocket ready and determines when to dispatch the client `READY` signal. @@ -749,7 +828,10 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: try: # wait to let guilds cache await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) except asyncio.TimeoutError: - logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") + # this will *mostly* occur when a guild has been shadow deleted by discord T&S. + # there is no way to check for this, so we just need to wait for this to time out. + # We still log it though, just in case. + self.logger.debug("Timeout waiting for guilds cache") break self._guild_event.clear() @@ -761,16 +843,9 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: # ensure all guilds have completed chunking for guild in self.guilds: if guild and not guild.chunked.is_set(): - logger.debug(f"Waiting for {guild.id} to chunk") + self.logger.debug(f"Waiting for {guild.id} to chunk") await guild.chunked.wait() - # run any pending startup tasks - if self.async_startup_tasks: - try: - await asyncio.gather(*self.async_startup_tasks) - except Exception as e: - self.dispatch(events.Error("async-extension-loader", e)) - # cache slash commands if not self._startup: await self._init_interactions() @@ -817,7 +892,7 @@ async def login(self, token) -> None: # so im gathering commands here self._gather_commands() - logger.debug("Attempting to login") + self.logger.debug("Attempting to login") me = await self.http.login(token.strip()) self._user = NaffUser.from_dict(me, self) self.cache.place_user_data(me) @@ -837,6 +912,13 @@ async def astart(self, token: str) -> None: token: Your bot's token """ await self.login(token) + + # run any pending startup tasks + if self.async_startup_tasks: + try: + await asyncio.gather(*self.async_startup_tasks) + except Exception as e: + self.dispatch(events.Error(source="async-extension-loader", error=e)) try: await self._connection_state.start() finally: @@ -865,7 +947,7 @@ async def start_gateway(self) -> None: async def stop(self) -> None: """Shutdown the bot.""" - logger.debug("Stopping the bot.") + self.logger.debug("Stopping the bot.") self._ready.clear() await self.http.close() await self._connection_state.stop() @@ -880,7 +962,7 @@ def dispatch(self, event: events.BaseEvent, *args, **kwargs) -> None: """ listeners = self.listeners.get(event.resolved_name, []) if listeners: - logger.debug(f"Dispatching Event: {event.resolved_name}") + self.logger.debug(f"Dispatching Event: {event.resolved_name}") event.bot = self for _listen in listeners: try: @@ -957,14 +1039,14 @@ async def wait_for_modal( author = to_snowflake(author) if author else None def predicate(event) -> bool: - if modal.custom_id != event.context.custom_id: + if modal.custom_id != event.ctx.custom_id: return False - if author and author != to_snowflake(event.context.author): + if author and author != to_snowflake(event.ctx.author): return False return True - resp = await self.wait_for("modal_response", predicate, timeout) - return resp.context + resp = await self.wait_for("modal_completion", predicate, timeout) + return resp.ctx async def wait_for_component( self, @@ -985,7 +1067,7 @@ async def wait_for_component( timeout: The number of seconds to wait before timing out. Returns: - `Component` that was invoked. Use `.context` to get the `ComponentContext`. + `Component` that was invoked. Use `.ctx` to get the `ComponentContext`. Raises: asyncio.TimeoutError: if timed out @@ -1004,7 +1086,7 @@ async def wait_for_component( custom_ids = [str(i) for i in custom_ids] def _check(event: Component) -> bool: - ctx: ComponentContext = event.context + ctx: ComponentContext = event.ctx # if custom_ids is empty or there is a match wanted_message = not message_ids or ctx.message.id in ( [message_ids] if isinstance(message_ids, int) else message_ids @@ -1068,19 +1150,30 @@ def add_listener(self, listener: Listener) -> None: listener Listener: The listener to add to the client """ - # check that the required intents are enabled - event_class_name = "".join([name.capitalize() for name in listener.event.split("_")]) - if event_class := globals().get(event_class_name): - if required_intents := _INTENT_EVENTS.get(event_class): # noqa - if not any(required_intent in self.intents for required_intent in required_intents): - self.logger.warning( - f"Event `{listener.event}` will not work since the required intent is not set -> Requires any of: `{required_intents}`" - ) + if not listener.is_default_listener: + # check that the required intents are enabled + + event_class_name = "".join([name.capitalize() for name in listener.event.split("_")]) + if event_class := globals().get(event_class_name): + if required_intents := _INTENT_EVENTS.get(event_class): # noqa + if not any(required_intent in self.intents for required_intent in required_intents): + self.logger.warning( + f"Event `{listener.event}` will not work since the required intent is not set -> Requires any of: `{required_intents}`" + ) if listener.event not in self.listeners: self.listeners[listener.event] = [] self.listeners[listener.event].append(listener) + # check if other listeners are to be deleted + default_listeners = [c_listener.is_default_listener for c_listener in self.listeners[listener.event]] + removes_defaults = [c_listener.disable_default_listeners for c_listener in self.listeners[listener.event]] + + if any(default_listeners) and any(removes_defaults): + self.listeners[listener.event] = [ + c_listener for c_listener in self.listeners[listener.event] if not c_listener.is_default_listener + ] + def add_interaction(self, command: InteractionCommand) -> bool: """ Add a slash command to the client. @@ -1092,10 +1185,15 @@ def add_interaction(self, command: InteractionCommand) -> bool: if self.debug_scope: command.scopes = [self.debug_scope] + if self.disable_dm_commands: + command.dm_permission = False + # for SlashCommand objs without callback (like objects made to hold group info etc) if command.callback is None: return False + base, group, sub, *_ = command.resolved_name.split(" ") + [None, None] + for scope in command.scopes: if scope not in self.interactions: self.interactions[scope] = {} @@ -1108,6 +1206,23 @@ def add_interaction(self, command: InteractionCommand) -> bool: self.interactions[scope][command.resolved_name] = command + if scope not in self.interaction_tree: + self.interaction_tree[scope] = {} + + if group is None or isinstance(command, ContextMenu): + self.interaction_tree[scope][command.resolved_name] = command + elif group is not None: + if not (current := self.interaction_tree[scope].get(base)) or isinstance(current, SlashCommand): + self.interaction_tree[scope][base] = {} + if sub is None: + self.interaction_tree[scope][base][group] = command + else: + if not (current := self.interaction_tree[scope][base].get(group)) or isinstance( + current, SlashCommand + ): + self.interaction_tree[scope][base][group] = {} + self.interaction_tree[scope][base][group][sub] = command + return True def add_hybrid_command(self, command: HybridCommand) -> bool: @@ -1121,7 +1236,7 @@ def add_hybrid_command(self, command: HybridCommand) -> bool: prefixed_base = self.prefixed_commands.get(str(command.name)) if not prefixed_base: prefixed_base = _base_subcommand_generator( - str(command.name), list((command.name.to_locale_dict() or {}).values()), str(command.description) + str(command.name), list(command.name.to_locale_dict().values()), str(command.description) ) self.add_prefixed_command(prefixed_base) @@ -1132,7 +1247,7 @@ def add_hybrid_command(self, command: HybridCommand) -> bool: if not prefixed_base: prefixed_base = _base_subcommand_generator( str(command.group_name), - list((command.group_name.to_locale_dict() or {}).values()), + list(command.group_name.to_locale_dict().values()), str(command.group_description), group=True, ) @@ -1228,7 +1343,7 @@ def process(_cmds) -> None: elif isinstance(func, Listener): self.add_listener(func) - logger.debug(f"{len(_cmds)} commands have been loaded from `__main__` and `client`") + self.logger.debug(f"{len(_cmds)} commands have been loaded from `__main__` and `client`") process( [obj for _, obj in inspect.getmembers(sys.modules["__main__"]) if isinstance(obj, (BaseCommand, Listener))] @@ -1259,7 +1374,7 @@ async def _init_interactions(self) -> None: else: await self._cache_interactions(warn_missing=False) except Exception as e: - self.dispatch(events.Error("Interaction Syncing", e)) + self.dispatch(events.Error(source="Interaction Syncing", error=e)) async def _cache_interactions(self, warn_missing: bool = False) -> None: """Get all interactions used by this bot and cache them.""" @@ -1285,7 +1400,7 @@ async def wrap(*args, **kwargs) -> Absent[List[Dict]]: for scope, remote_cmds in results.items(): if remote_cmds == MISSING: - logger.debug(f"Bot was not invited to guild {scope} with `application.commands` scope") + self.logger.debug(f"Bot was not invited to guild {scope} with `application.commands` scope") continue remote_cmds = {cmd_data["name"]: cmd_data for cmd_data in remote_cmds} @@ -1297,7 +1412,7 @@ async def wrap(*args, **kwargs) -> Absent[List[Dict]]: if cmd_data is MISSING: if cmd_name not in found: if warn_missing: - logger.error( + self.logger.error( f'Detected yet to sync slash command "/{cmd_name}" for scope ' f"{'global' if scope == GLOBAL_SCOPE else scope}" ) @@ -1309,7 +1424,7 @@ async def wrap(*args, **kwargs) -> Absent[List[Dict]]: if warn_missing: for cmd_data in remote_cmds.values(): - logger.error( + self.logger.error( f"Detected unimplemented slash command \"/{cmd_data['name']}\" for scope " f"{'global' if scope == GLOBAL_SCOPE else scope}" ) @@ -1337,7 +1452,7 @@ async def synchronise_interactions( # if we're not deleting, just check the scopes we have cmds registered in cmd_scopes = list(set(self.interactions) | {GLOBAL_SCOPE}) - local_cmds_json = application_commands_to_dict(self.interactions) + local_cmds_json = application_commands_to_dict(self.interactions, self) async def sync_scope(cmd_scope) -> None: @@ -1348,7 +1463,7 @@ async def sync_scope(cmd_scope) -> None: try: remote_commands = await self.http.get_application_commands(self.app.id, cmd_scope) except Forbidden: - logger.warning(f"Bot is lacking `application.commands` scope in {cmd_scope}!") + self.logger.warning(f"Bot is lacking `application.commands` scope in {cmd_scope}!") return for local_cmd in self.interactions.get(cmd_scope, {}).values(): @@ -1378,13 +1493,13 @@ async def sync_scope(cmd_scope) -> None: if sync_needed_flag or (_delete_cmds and len(sync_payload) < len(remote_commands)): # synchronise commands if flag is set, or commands are to be deleted - logger.info(f"Overwriting {cmd_scope} with {len(sync_payload)} application commands") + self.logger.info(f"Overwriting {cmd_scope} with {len(sync_payload)} application commands") sync_response: list[dict] = await self.http.overwrite_application_commands( self.app.id, sync_payload, cmd_scope ) self._cache_sync_response(sync_response, cmd_scope) else: - logger.debug(f"{cmd_scope} is already up-to-date with {len(remote_commands)} commands.") + self.logger.debug(f"{cmd_scope} is already up-to-date with {len(remote_commands)} commands.") except Forbidden as e: raise InteractionMissingAccess(cmd_scope) from e @@ -1394,7 +1509,7 @@ async def sync_scope(cmd_scope) -> None: await asyncio.gather(*[sync_scope(scope) for scope in cmd_scopes]) t = time.perf_counter() - s - logger.debug(f"Sync of {len(cmd_scopes)} scopes took {t} seconds") + self.logger.debug(f"Sync of {len(cmd_scopes)} scopes took {t} seconds") def get_application_cmd_by_id(self, cmd_id: "Snowflake_Type") -> Optional[InteractionCommand]: """ @@ -1415,8 +1530,7 @@ def get_application_cmd_by_id(self, cmd_id: "Snowflake_Type") -> Optional[Intera return cmd return None - @staticmethod - def _raise_sync_exception(e: HTTPException, cmds_json: dict, cmd_scope: "Snowflake_Type") -> NoReturn: + def _raise_sync_exception(self, e: HTTPException, cmds_json: dict, cmd_scope: "Snowflake_Type") -> NoReturn: try: if isinstance(e.errors, dict): for cmd_num in e.errors.keys(): @@ -1424,9 +1538,9 @@ def _raise_sync_exception(e: HTTPException, cmds_json: dict, cmd_scope: "Snowfla output = e.search_for_message(e.errors[cmd_num], cmd) if len(output) > 1: output = "\n".join(output) - logger.error(f"Multiple Errors found in command `{cmd['name']}`:\n{output}") + self.logger.error(f"Multiple Errors found in command `{cmd['name']}`:\n{output}") else: - logger.error(f"Error in command `{cmd['name']}`: {output[0]}") + self.logger.error(f"Error in command `{cmd['name']}`: {output[0]}") else: raise e from None except Exception: @@ -1564,7 +1678,7 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: ctx = await self.get_context(interaction_data, True) ctx.command: SlashCommand = self.interactions[scope][ctx.invoke_target] # type: ignore - logger.debug(f"{scope} :: {ctx.command.name} should be called") + self.logger.debug(f"{scope} :: {ctx.command.name} should be called") if ctx.command.auto_defer: auto_defer = ctx.command.auto_defer @@ -1577,9 +1691,9 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: try: await ctx.command.autocomplete_callbacks[auto_opt](ctx, **ctx.kwargs) except Exception as e: - await self.on_autocomplete_error(ctx, e) + self.dispatch(events.AutocompleteError(ctx=ctx, error=e)) finally: - await self.on_autocomplete(ctx) + self.dispatch(events.AutocompleteCompletion(ctx=ctx)) else: try: await auto_defer(ctx) @@ -1589,18 +1703,18 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: if self.post_run_callback: await self.post_run_callback(ctx, **ctx.kwargs) except Exception as e: - await self.on_command_error(ctx, e) + self.dispatch(events.CommandError(ctx=ctx, error=e)) finally: - await self.on_command(ctx) + self.dispatch(events.CommandCompletion(ctx=ctx)) else: - logger.error(f"Unknown cmd_id received:: {interaction_id} ({name})") + self.logger.error(f"Unknown cmd_id received:: {interaction_id} ({name})") elif interaction_data["type"] == InteractionTypes.MESSAGE_COMPONENT: # Buttons, Selects, ContextMenu::Message ctx = await self.get_context(interaction_data, True) component_type = interaction_data["data"]["component_type"] - self.dispatch(events.Component(ctx)) + self.dispatch(events.Component(ctx=ctx)) if callback := self._component_callbacks.get(ctx.custom_id): ctx.command = callback try: @@ -1610,17 +1724,17 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: if self.post_run_callback: await self.post_run_callback(ctx) except Exception as e: - await self.on_component_error(ctx, e) + self.dispatch(events.ComponentError(ctx=ctx, error=e)) finally: - await self.on_component(ctx) + self.dispatch(events.ComponentCompletion(ctx=ctx)) if component_type == ComponentTypes.BUTTON: - self.dispatch(events.Button(ctx)) - if component_type == ComponentTypes.SELECT: + self.dispatch(events.ButtonPressed(ctx)) + if component_type == ComponentTypes.STRING_SELECT: self.dispatch(events.Select(ctx)) elif interaction_data["type"] == InteractionTypes.MODAL_RESPONSE: ctx = await self.get_context(interaction_data, True) - self.dispatch(events.ModalResponse(ctx)) + self.dispatch(events.ModalCompletion(ctx=ctx)) # todo: Polls remove this icky code duplication - love from past-polls ❤ī¸ if callback := self._modal_callbacks.get(ctx.custom_id): @@ -1633,14 +1747,12 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: if self.post_run_callback: await self.post_run_callback(ctx) except Exception as e: - await self.on_component_error(ctx, e) - finally: - await self.on_component(ctx) + self.dispatch(events.ModalError(ctx=ctx, error=e)) else: raise NotImplementedError(f"Unknown Interaction Received: {interaction_data['type']}") - @Listener.create("message_create") + @Listener.create("message_create", is_default_listener=True) async def _dispatch_prefixed_commands(self, event: MessageCreate) -> None: """Determine if a prefixed command is being triggered, and dispatch it.""" message = event.message @@ -1701,9 +1813,9 @@ async def _dispatch_prefixed_commands(self, event: MessageCreate) -> None: if new_command.error_callback: await new_command.error_callback(e, context) elif new_command.extension and new_command.extension.extension_error: - await new_command.extension.extension_error(context) + await new_command.extension.extension_error(e, context) else: - await self.on_command_error(context, e) + self.dispatch(events.CommandError(ctx=context, error=e)) return if not isinstance(command, PrefixedCommand): @@ -1712,9 +1824,7 @@ async def _dispatch_prefixed_commands(self, event: MessageCreate) -> None: if command and command.enabled: # yeah, this looks ugly context.command = command - context.invoke_target = ( - message.content.removeprefix(prefix_used).removesuffix(content_parameters).strip() # type: ignore - ) + context.invoke_target = message.content.removeprefix(prefix_used).removesuffix(content_parameters).strip() # type: ignore context.args = get_args(context.content_parameters) try: if self.pre_run_callback: @@ -1723,11 +1833,11 @@ async def _dispatch_prefixed_commands(self, event: MessageCreate) -> None: if self.post_run_callback: await self.post_run_callback(context) except Exception as e: - await self.on_command_error(context, e) + self.dispatch(events.CommandError(ctx=context, error=e)) finally: - await self.on_command(context) + self.dispatch(events.CommandCompletion(ctx=context)) - @Listener.create("disconnect") + @Listener.create("disconnect", is_default_listener=True) async def _disconnect(self) -> None: self._ready.clear() @@ -1770,27 +1880,37 @@ def load_extension(self, name: str, package: str | None = None, **load_kwargs: A **load_kwargs: The auto-filled mapping of the load keyword arguments """ - name = importlib.util.resolve_name(name, package) - if name in self.__modules: - raise Exception(f"{name} already loaded") + module_name = importlib.util.resolve_name(name, package) + if module_name in self.__modules: + raise Exception(f"{module_name} already loaded") - module = importlib.import_module(name, package) + module = importlib.import_module(module_name, package) try: setup = getattr(module, "setup", None) - if not setup: - raise ExtensionLoadException( - f"{name} lacks an entry point. Ensure you have a function called `setup` defined in that file" - ) from None - setup(self, **load_kwargs) + if setup: + setup(self, **load_kwargs) + else: + self.logger.debug("No setup function found in %s", module_name) + + found = False + objects = {name: obj for name, obj in inspect.getmembers(module) if isinstance(obj, type)} + for obj_name, obj in objects.items(): + if Extension in obj.__bases__: + self.logger.debug(f"Found extension class {obj_name} in {module_name}: Attempting to load") + obj(self, **load_kwargs) + found = True + if not found: + raise Exception(f"{module_name} contains no Extensions") + except ExtensionLoadException: raise except Exception as e: - del sys.modules[name] - raise ExtensionLoadException(f"Unexpected Error loading {name}") from e + del sys.modules[module_name] + raise ExtensionLoadException(f"Unexpected Error loading {module_name}") from e else: - logger.debug(f"Loaded Extension: {name}") - self.__modules[name] = module + self.logger.debug(f"Loaded Extension: {module_name}") + self.__modules[module_name] = module if self.sync_ext and self._ready.is_set(): try: @@ -1857,7 +1977,7 @@ def reload_extension( module = self.__modules.get(name) if module is None: - logger.warning("Attempted to reload extension thats not loaded. Loading extension instead") + self.logger.warning("Attempted to reload extension thats not loaded. Loading extension instead") return self.load_extension(name, package) if not load_kwargs: diff --git a/naff/client/const.py b/naff/client/const.py index 0fed26a71..d70a67778 100644 --- a/naff/client/const.py +++ b/naff/client/const.py @@ -45,7 +45,7 @@ "__repo_url__", "__py_version__", "__api_version__", - "logger", + "get_logger", "logger_name", "kwarg_spam", "DISCORD_EPOCH", @@ -86,7 +86,14 @@ __py_version__ = f"{_ver_info[0]}.{_ver_info[1]}" __api_version__ = 10 logger_name = "naff" -logger = logging.getLogger(logger_name) +_logger = logging.getLogger(logger_name) + + +def get_logger() -> logging.Logger: + global _logger + return _logger + + default_locale = "english_us" kwarg_spam = False diff --git a/naff/client/errors.py b/naff/client/errors.py index 467c5d743..e9d5ce5fd 100644 --- a/naff/client/errors.py +++ b/naff/client/errors.py @@ -2,8 +2,8 @@ import aiohttp -from . import const from naff.client.utils.misc_utils import escape_mentions +from . import const if TYPE_CHECKING: from naff.models.naff.command import BaseCommand @@ -115,6 +115,9 @@ def __str__(self) -> str: out = f"HTTPException: {self.status}|{self.response.reason} || {self.text}" return out + def __repr__(self) -> str: + return str(self) + @staticmethod def search_for_message(errors: dict, lookup: Optional[dict] = None) -> list[str]: """ @@ -315,7 +318,7 @@ class CommandCheckFailure(CommandException): def __init__(self, command: "BaseCommand", check: Callable[..., Coroutine], context: "Context") -> None: self.command: "BaseCommand" = command self.check: Callable[..., Coroutine] = check - self.context = context + self.ctx = context class BadArgument(CommandException): @@ -347,7 +350,9 @@ class EphemeralEditException(MessageException): """ def __init__(self) -> None: - super().__init__("Ephemeral messages cannot be edited.") + super().__init__( + "Ephemeral messages can only be edited with component's `edit_origin` method or using InteractionContext" + ) class ThreadException(BotException): diff --git a/naff/client/mixins/send.py b/naff/client/mixins/send.py index 159efb744..de43f5d66 100644 --- a/naff/client/mixins/send.py +++ b/naff/client/mixins/send.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union import naff.models as models @@ -24,6 +24,7 @@ async def _send_http_request(self, message_payload: dict, files: Iterable["UPLOA async def send( self, content: Optional[str] = None, + *, embeds: Optional[Union[Iterable[Union["Embed", dict]], Union["Embed", dict]]] = None, embed: Optional[Union["Embed", dict]] = None, components: Optional[ @@ -43,7 +44,7 @@ async def send( suppress_embeds: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, delete_after: Optional[float] = None, - **kwargs, + **kwargs: Any, ) -> "Message": """ Send a message. diff --git a/naff/client/mixins/serialization.py b/naff/client/mixins/serialization.py index f1d483fcb..4971cb0d6 100644 --- a/naff/client/mixins/serialization.py +++ b/naff/client/mixins/serialization.py @@ -1,16 +1,18 @@ +from logging import Logger from typing import Any, Dict, List, Type import attrs import naff.client.const as const -from naff.client.utils.attr_utils import define import naff.client.utils.serializer as serializer __all__ = ("DictSerializationMixin",) -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class DictSerializationMixin: + logger: Logger = attrs.field(init=False, factory=const.get_logger, metadata=serializer.no_export_meta, repr=False) + @classmethod def _get_keys(cls) -> frozenset: if (keys := getattr(cls, "_keys", None)) is None: @@ -30,7 +32,7 @@ def _get_init_keys(cls) -> frozenset: def _filter_kwargs(cls, kwargs_dict: dict, keys: frozenset) -> dict: if const.kwarg_spam: unused = {k: v for k, v in kwargs_dict.items() if k not in keys} - const.logger.debug(f"Unused kwargs: {cls.__name__}: {unused}") # for debug + const.get_logger().debug(f"Unused kwargs: {cls.__name__}: {unused}") # for debug return {k: v for k, v in kwargs_dict.items() if k in keys} @classmethod @@ -91,7 +93,6 @@ def update_from_dict(self: Type[const.T], data: Dict[str, Any]) -> const.T: """ data = self._process_dict(data) for key, value in self._filter_kwargs(data, self._get_keys()).items(): - # todo improve setattr(self, key, value) return self diff --git a/naff/client/smart_cache.py b/naff/client/smart_cache.py index 201e1b2eb..571ebee88 100644 --- a/naff/client/smart_cache.py +++ b/naff/client/smart_cache.py @@ -1,12 +1,13 @@ from contextlib import suppress +from logging import Logger from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union +import attrs import discord_typings -from naff.client.const import MISSING, logger, Absent +from naff.client.const import Absent, MISSING, get_logger from naff.client.errors import NotFound, Forbidden -from naff.client.utils.attr_utils import define, field -from naff.client.utils.cache import TTLCache +from naff.client.utils.cache import TTLCache, NullCache from naff.models import VoiceState from naff.models.discord.channel import BaseChannel, GuildChannel, ThreadChannel from naff.models.discord.emoji import CustomEmoji @@ -28,7 +29,7 @@ def create_cache( ttl: Optional[int] = 60, hard_limit: Optional[int] = 250, soft_limit: Absent[Optional[int]] = MISSING -) -> Union[dict, TTLCache]: +) -> Union[dict, TTLCache, NullCache]: """ Create a cache object based on the parameters passed. @@ -45,39 +46,45 @@ def create_cache( """ if ttl is None and hard_limit is None: return {} + if ttl == 0 and hard_limit == 0 and soft_limit == 0: + return NullCache() else: if not soft_limit: soft_limit = int(hard_limit / 4) if hard_limit else 50 return TTLCache(hard_limit=hard_limit or float("inf"), soft_limit=soft_limit or 0, ttl=ttl or float("inf")) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class GlobalCache: - _client: "Client" = field() + _client: "Client" = attrs.field( + repr=False, + ) # Non expiring discord objects cache - user_cache: dict = field(factory=dict) # key: user_id - member_cache: dict = field(factory=dict) # key: (guild_id, user_id) - channel_cache: dict = field(factory=dict) # key: channel_id - guild_cache: dict = field(factory=dict) # key: guild_id + user_cache: dict = attrs.field(repr=False, factory=dict) # key: user_id + member_cache: dict = attrs.field(repr=False, factory=dict) # key: (guild_id, user_id) + channel_cache: dict = attrs.field(repr=False, factory=dict) # key: channel_id + guild_cache: dict = attrs.field(repr=False, factory=dict) # key: guild_id # Expiring discord objects cache - message_cache: TTLCache = field(factory=TTLCache) # key: (channel_id, message_id) - role_cache: TTLCache = field(factory=dict) # key: role_id - voice_state_cache: TTLCache = field(factory=dict) # key: user_id - bot_voice_state_cache: dict = field(factory=dict) # key: guild_id + message_cache: TTLCache = attrs.field(repr=False, factory=TTLCache) # key: (channel_id, message_id) + role_cache: TTLCache = attrs.field(repr=False, factory=dict) # key: role_id + voice_state_cache: TTLCache = attrs.field(repr=False, factory=dict) # key: user_id + bot_voice_state_cache: dict = attrs.field(repr=False, factory=dict) # key: guild_id - enable_emoji_cache: bool = field(default=False) + enable_emoji_cache: bool = attrs.field(repr=False, default=False) """If the emoji cache should be enabled. Default: False""" - emoji_cache: Optional[dict] = field(default=None, init=False) # key: emoji_id + emoji_cache: Optional[dict] = attrs.field(repr=False, default=None, init=False) # key: emoji_id # Expiring id reference cache - dm_channels: TTLCache = field(factory=TTLCache) # key: user_id - user_guilds: TTLCache = field(factory=dict) # key: user_id; value: set[guild_id] + dm_channels: TTLCache = attrs.field(repr=False, factory=TTLCache) # key: user_id + user_guilds: TTLCache = attrs.field(repr=False, factory=dict) # key: user_id; value: set[guild_id] + + logger: Logger = attrs.field(repr=False, init=False, factory=get_logger) def __attrs_post_init__(self) -> None: if not isinstance(self.message_cache, TTLCache): - logger.warning( + self.logger.warning( "Disabling cache limits for message_cache is not recommended! This can result in very high memory usage" ) @@ -446,7 +453,7 @@ async def fetch_channel( data = await self._client.http.get_channel(channel_id) channel = self.place_channel_data(data) except Forbidden: - logger.warning(f"Forbidden access to channel {channel_id}. Generating fallback channel object") + self.logger.warning(f"Forbidden access to channel {channel_id}. Generating fallback channel object") channel = BaseChannel.from_dict({"id": channel_id, "type": MISSING}, self._client) return channel diff --git a/naff/client/utils/__init__.py b/naff/client/utils/__init__.py index a3873878e..b28b4c02c 100644 --- a/naff/client/utils/__init__.py +++ b/naff/client/utils/__init__.py @@ -5,3 +5,4 @@ from .misc_utils import * from .serializer import * from .formatting import * +from .text_utils import * diff --git a/naff/client/utils/attr_utils.py b/naff/client/utils/attr_utils.py index 83b8bbad6..0ca019248 100644 --- a/naff/client/utils/attr_utils.py +++ b/naff/client/utils/attr_utils.py @@ -4,7 +4,7 @@ import attrs from attr import Attribute -from naff.client.const import MISSING, logger +from naff.client.const import MISSING, get_logger __all__ = ("define", "field", "docs", "str_validator") @@ -50,7 +50,7 @@ def str_validator(self: Any, attribute: attrs.Attribute, value: Any) -> None: if value is MISSING: return setattr(self, attribute.name, str(value)) - logger.warning( + get_logger().warning( f"Value of {attribute.name} has been automatically converted to a string. Please use strings in future.\n" "Note: Discord will always return value as a string" ) diff --git a/naff/client/utils/cache.py b/naff/client/utils/cache.py index 95b0ae0ea..e89c99ccb 100644 --- a/naff/client/utils/cache.py +++ b/naff/client/utils/cache.py @@ -5,18 +5,31 @@ import attrs -from naff.client.utils.attr_utils import define, field - -__all__ = ("TTLItem", "TTLCache") +__all__ = ("TTLItem", "TTLCache", "NullCache") KT = TypeVar("KT") VT = TypeVar("VT") -@define(kw_only=False) +class NullCache(dict): + """ + A special cache that will always return None + + Effectively just a lazy way to disable caching. + """ + + def __setitem__(self, key, value) -> None: + pass + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class TTLItem(Generic[VT]): - value: VT = field() - expire: float = field() + value: VT = attrs.field( + repr=False, + ) + expire: float = attrs.field( + repr=False, + ) """When the item expires in cache.""" def is_expired(self, timestamp: float) -> bool: diff --git a/naff/client/utils/input_utils.py b/naff/client/utils/input_utils.py index f105bf041..4925a4e07 100644 --- a/naff/client/utils/input_utils.py +++ b/naff/client/utils/input_utils.py @@ -3,15 +3,14 @@ import aiohttp # type: ignore - -from naff.client.const import logger +from naff.client.const import get_logger __all__ = ("OverriddenJson", "response_decode", "get_args", "get_first_word") try: import orjson as json except ImportError: - logger.warning("orjson not installed, built-in json library will be used") + get_logger().warning("orjson not installed, built-in json library will be used") import json as json diff --git a/naff/client/utils/serializer.py b/naff/client/utils/serializer.py index eb2c499bf..2f997d34e 100644 --- a/naff/client/utils/serializer.py +++ b/naff/client/utils/serializer.py @@ -30,10 +30,15 @@ def to_dict(inst) -> dict: The processed dict. """ + attrs = fields(inst.__class__) + if (converter := getattr(inst, "as_dict", None)) is not None: - return converter() + d = converter() + for a in attrs: + if a.metadata.get("no_export", False): + d.pop(a.name, None) + return d - attrs = fields(inst.__class__) d = {} for a in attrs: diff --git a/naff/client/utils/text_utils.py b/naff/client/utils/text_utils.py new file mode 100644 index 000000000..4d38adfcb --- /dev/null +++ b/naff/client/utils/text_utils.py @@ -0,0 +1,33 @@ +import re +import naff.models as models + +__all__ = ("mentions",) + + +def mentions( + text: str, + query: "str | re.Pattern[str] | models.BaseUser | models.BaseChannel | models.Role", + *, + tag_as_mention: bool = False, +) -> bool: + """Checks whether a query is present in a text. + + Args: + text: The text to search in + query: The query to search for + tag_as_mention: Should `BaseUser.tag` be checked *(only if query is an instance of BaseUser)* + + Returns: + Whether the query could be found in the text + """ + if isinstance(query, str): + return query in text + elif isinstance(query, re.Pattern): + return query.match(text) is not None + elif isinstance(query, models.BaseUser): + # mentions with <@!ID> aren't detected without the replacement + return (query.mention in text.replace("@!", "@")) or (query.tag in text if tag_as_mention else False) + elif isinstance(query, (models.BaseChannel, models.Role)): + return query.mention in text + else: + return False diff --git a/naff/ext/debug_extension/__init__.py b/naff/ext/debug_extension/__init__.py index 137ec4c60..183dc23eb 100644 --- a/naff/ext/debug_extension/__init__.py +++ b/naff/ext/debug_extension/__init__.py @@ -1,8 +1,9 @@ -import logging +import asyncio import platform +import tracemalloc from naff import Client, Extension, listen, slash_command, InteractionContext, Timestamp, TimestampStyles, Intents -from naff.client.const import logger, __version__, __py_version__ +from naff.client.const import get_logger, __version__, __py_version__ from naff.models.naff import checks from .debug_application_cmd import DebugAppCMD from .debug_exec import DebugExec @@ -14,16 +15,25 @@ class DebugExtension(DebugExec, DebugAppCMD, DebugExts, Extension): def __init__(self, bot: Client) -> None: + bot.logger.info("Debug Extension is mounting!") + super().__init__(bot) self.add_ext_check(checks.is_owner()) - logger.info("Debug Extension is growing!") + bot.logger.info("Debug Extension is growing!") + tracemalloc.start() + bot.logger.warning("Tracemalloc started") + + async def async_start(self) -> None: + loop = asyncio.get_running_loop() + loop.set_debug(True) + self.bot.logger.warning("Asyncio debug mode is enabled") @listen() async def on_startup(self) -> None: - logger.info(f"Started {self.bot.user.tag} [{self.bot.user.id}] in Debug Mode") + self.bot.logger.info(f"Started {self.bot.user.tag} [{self.bot.user.id}] in Debug Mode") - logger.info(f"Caching System State: \n{get_cache_state(self.bot)}") + self.bot.logger.info(f"Caching System State: \n{get_cache_state(self.bot)}") @slash_command( "debug", diff --git a/naff/ext/debug_extension/debug_application_cmd.py b/naff/ext/debug_extension/debug_application_cmd.py index 83b2b1373..8d7745226 100644 --- a/naff/ext/debug_extension/debug_application_cmd.py +++ b/naff/ext/debug_extension/debug_application_cmd.py @@ -88,7 +88,7 @@ async def send(cmd_json: dict) -> None: ) if not remote: - data = application_commands_to_dict(self.bot.interactions)[scope] + data = application_commands_to_dict(self.bot.interactions, self.bot)[scope] cmd_obj = self.bot.get_application_cmd_by_id(cmd_id) for cmd in data: if cmd["name"] == cmd_obj.name: diff --git a/naff/ext/debug_extension/utils.py b/naff/ext/debug_extension/utils.py index 538407dcb..5ccc6a484 100644 --- a/naff/ext/debug_extension/utils.py +++ b/naff/ext/debug_extension/utils.py @@ -1,8 +1,9 @@ import datetime import inspect +import weakref from typing import TYPE_CHECKING, Any, Optional, Union -from naff.client.utils.cache import TTLCache +from naff.client.utils.cache import TTLCache, NullCache from naff.models import Embed, MaterialColors if TYPE_CHECKING: @@ -28,19 +29,26 @@ def debug_embed(title: str, **kwargs) -> Embed: def get_cache_state(bot: "Client") -> str: """Create a nicely formatted table of internal cache state.""" - caches = [ - c[0] + caches = { + c[0]: getattr(bot.cache, c[0]) for c in inspect.getmembers(bot.cache, predicate=lambda x: isinstance(x, dict)) if not c[0].startswith("__") - ] + } + caches["endpoints"] = bot.http._endpoints + caches["rate_limits"] = bot.http.ratelimit_locks table = [] - for cache in caches: - val = getattr(bot.cache, cache) + for cache, val in caches.items(): if isinstance(val, TTLCache): amount = [len(val), f"{val.hard_limit}({val.soft_limit})"] expire = f"{val.ttl}s" + elif isinstance(val, NullCache): + amount = ("DISABLED",) + expire = "N/A" + elif isinstance(val, (weakref.WeakValueDictionary, weakref.WeakKeyDictionary)): + amount = [len(val), "∞"] + expire = "w_ref" else: amount = [len(val), "∞"] expire = "none" @@ -48,10 +56,6 @@ def get_cache_state(bot: "Client") -> str: row = [cache.removesuffix("_cache"), amount, expire] table.append(row) - # http caches - table.append(["endpoints", [len(bot.http._endpoints), "∞"], "none"]) - table.append(["ratelimits", [len(bot.http.ratelimit_locks), "∞"], "w_ref"]) - adjust_subcolumn(table, 1, aligns=[">", "<"]) labels = ["Cache", "Amount", "Expire"] diff --git a/naff/ext/jurigged.py b/naff/ext/jurigged.py new file mode 100644 index 000000000..98913b6ab --- /dev/null +++ b/naff/ext/jurigged.py @@ -0,0 +1,209 @@ +import inspect +from pathlib import Path +from types import ModuleType +from typing import Callable, Dict + +from naff import Extension, SlashCommand, listen +from naff.client.errors import ExtensionLoadException, ExtensionNotFound +from naff.client.utils.misc_utils import find +from naff.client.const import get_logger + +try: + from jurigged import watch, CodeFile + from jurigged.live import WatchOperation + from jurigged.codetools import ( + AddOperation, + DeleteOperation, + UpdateOperation, + LineDefinition, + ) +except ModuleNotFoundError: + get_logger().error( + "jurigged not installed, cannot enable jurigged integration. Install with `pip install naff[jurigged]`" + ) + raise + + +__all__ = ("Jurigged", "setup") + + +def get_all_commands(module: ModuleType) -> Dict[str, Callable]: + """ + Get all SlashCommands from a specified module. + + Args: + module: Module to extract commands from + """ + commands = {} + + def is_extension(e) -> bool: + """Check that an object is an extension.""" + return inspect.isclass(e) and issubclass(e, Extension) and e is not Extension + + def is_slashcommand(e) -> bool: + """Check that an object is a slash command.""" + return isinstance(e, SlashCommand) + + for _name, item in inspect.getmembers(module, is_extension): + inspect_result = inspect.getmembers(item, is_slashcommand) + exts = [] + for _, val in inspect_result: + exts.append(val) + commands[f"{module.__name__}"] = exts + + return {k: v for k, v in commands.items() if v is not None} + + +class Jurigged(Extension): + @listen(event_name="on_startup") + async def jurigged_startup(self) -> None: + """Jurigged starting utility.""" + self.command_cache = {} + self.bot.logger.warning("Setting sync_ext to True by default for syncing changes") + self.bot.sync_ext = True + + self.bot.logger.info("Loading jurigged") + path = Path().resolve() + self.watcher = watch(f"{path}/[!.]*.py", logger=self.jurigged_log) + self.watcher.prerun.register(self.jurigged_prerun) + self.watcher.postrun.register(self.jurigged_postrun) + + def jurigged_log(self, event: WatchOperation | AddOperation | DeleteOperation | UpdateOperation) -> None: + """ + Log a jurigged event + + Args: + event: jurigged event + """ + if isinstance(event, WatchOperation): + self.bot.logger.debug(f"Watch {event.filename}") + elif isinstance(event, (Exception, SyntaxError)): + self.bot.logger.exception("Jurigged encountered an error", exc_info=True) + else: + event_str = "{action} {dotpath}:{lineno}{extra}" + action = None + lineno = event.defn.stashed.lineno + dotpath = event.defn.dotpath() + extra = "" + + if isinstance(event.defn, LineDefinition): + dotpath = event.defn.parent.dotpath() + extra = f" | {event.defn.text}" + + if isinstance(event, AddOperation): + action = "Add" + if isinstance(event.defn, LineDefinition): + action = "Run" + elif isinstance(event, UpdateOperation): + action = "Update" + elif isinstance(event, DeleteOperation): + action = "Delete" + if not action: + self.bot.logger.debug(event) + else: + self.bot.logger.debug(event_str.format(action=action, dotpath=dotpath, lineno=lineno, extra=extra)) + + def jurigged_prerun(self, _path: str, cf: CodeFile) -> None: + """ + Jurigged prerun event. + + Args: + path: Path to file + cf: File information + """ + if self.bot.get_ext(cf.module_name): + self.bot.logger.debug(f"Caching {cf.module_name}") + self.command_cache = get_all_commands(cf.module) + + def jurigged_postrun(self, _path: str, cf: CodeFile) -> None: + """ + Jurigged postrun event. + + Args: + path: Path to file + cf: File information + """ + if self.bot.get_ext(cf.module_name): + self.bot.logger.debug(f"Checking {cf.module_name}") + commands = get_all_commands(cf.module) + + self.bot.logger.debug("Checking for changes") + for module, cmds in commands.items(): + # Check if a module was removed + if module not in commands: + self.bot.logger.debug(f"Module {module} removed") + self.bot.unload_extension(module) + + # Check if a module is new + elif module not in self.command_cache: + self.bot.logger.debug(f"Module {module} added") + try: + self.bot.load_extension(module) + except ExtensionLoadException: + self.bot.logger.warning(f"Failed to load new module {module}") + + # Check if a module has more/less commands + elif len(self.command_cache[module]) != len(cmds): + self.bot.logger.debug("Number of commands changed, reloading") + try: + self.bot.reload_extension(module) + except ExtensionNotFound: + try: + self.bot.load_extension(module) + except ExtensionLoadException: + self.bot.logger.warning(f"Failed to update module {module}") + except ExtensionLoadException: + self.bot.logger.warning(f"Failed to update module {module}") + + # Check each command for differences + else: + for cmd in cmds: + old_cmd = find( + lambda x, cmd=cmd: x.resolved_name == cmd.resolved_name, + self.command_cache[module], + ) + + # Extract useful info + old_args = old_cmd.options + old_arg_names = [] + new_arg_names = [] + if old_args: + old_arg_names = [x.name.default for x in old_args] + new_args = cmd.options + if new_args: + new_arg_names = [x.name.default for x in new_args] + + # No changes + if not old_args and not new_args: + continue + + # Check if number of args has changed + if len(old_arg_names) != len(new_arg_names): + self.bot.logger.debug("Number of arguments changed, reloading") + try: + self.bot.reload_extension(module) + except Exception: + self.bot.logger.exception(f"Failed to update module {module}", exc_info=True) + + # Check if arg names have changed + elif len(set(old_arg_names) - set(new_arg_names)) > 0: + self.bot.logger.debug("Argument names changed, reloading") + try: + self.bot.reload_extension(module) + except Exception: + self.bot.logger.exception(f"Failed to update module {module}", exc_info=True) + + # Check if arg types have changed + elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)): + self.bot.logger.debug("Argument types changed, reloading") + try: + self.bot.reload_extension(module) + except Exception: + self.bot.logger.exception(f"Failed to update module {module}", exc_info=True) + else: + self.bot.logger.debug("No changes detected") + self.command_cache.clear() + + +def setup(bot) -> None: + Jurigged(bot) diff --git a/naff/ext/paginators.py b/naff/ext/paginators.py index e66ae390d..7195f5e9e 100644 --- a/naff/ext/paginators.py +++ b/naff/ext/paginators.py @@ -3,6 +3,8 @@ import uuid from typing import Callable, Coroutine, List, Optional, Sequence, TYPE_CHECKING, Union +import attrs + from naff import ( Embed, ComponentContext, @@ -16,12 +18,11 @@ Message, MISSING, Snowflake_Type, - Select, + StringSelectMenu, SelectOption, Color, BrandColors, ) -from naff.client.utils.attr_utils import define, field from naff.client.utils.serializer import export_converter from naff.models.discord.emoji import process_emoji @@ -32,11 +33,13 @@ __all__ = ("Paginator",) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Timeout: - paginator: "Paginator" = field() + paginator: "Paginator" = attrs.field( + repr=False, + ) """The paginator that this timeout is associated with.""" - run: bool = field(default=True) + run: bool = attrs.field(repr=False, default=True) """Whether or not this timeout is currently running.""" ping: asyncio.Event = asyncio.Event() """The event that is used to wait the paginator action.""" @@ -53,15 +56,17 @@ async def __call__(self) -> None: self.ping.clear() -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Page: - content: str = field() + content: str = attrs.field( + repr=False, + ) """The content of the page.""" - title: Optional[str] = field(default=None) + title: Optional[str] = attrs.field(repr=False, default=None) """The title of the page.""" - prefix: str = field(kw_only=True, default="") + prefix: str = attrs.field(repr=False, kw_only=True, default="") """Content that is prepended to the page.""" - suffix: str = field(kw_only=True, default="") + suffix: str = attrs.field(repr=False, kw_only=True, default="") """Content that is appended to the page.""" @property @@ -74,68 +79,70 @@ def to_embed(self) -> Embed: return Embed(description=f"{self.prefix}\n{self.content}\n{self.suffix}", title=self.title) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Paginator: - client: "Client" = field() + client: "Client" = attrs.field( + repr=False, + ) """The NAFF client to hook listeners into""" - page_index: int = field(kw_only=True, default=0) + page_index: int = attrs.field(repr=False, kw_only=True, default=0) """The index of the current page being displayed""" - pages: Sequence[Page | Embed] = field(factory=list, kw_only=True) + pages: Sequence[Page | Embed] = attrs.field(repr=False, factory=list, kw_only=True) """The pages this paginator holds""" - timeout_interval: int = field(default=0, kw_only=True) + timeout_interval: int = attrs.field(repr=False, default=0, kw_only=True) """How long until this paginator disables itself""" - callback: Callable[..., Coroutine] = field(default=None) + callback: Callable[..., Coroutine] = attrs.field(repr=False, default=None) """A coroutine to call should the select button be pressed""" - show_first_button: bool = field(default=True) + show_first_button: bool = attrs.field(repr=False, default=True) """Should a `First` button be shown""" - show_back_button: bool = field(default=True) + show_back_button: bool = attrs.field(repr=False, default=True) """Should a `Back` button be shown""" - show_next_button: bool = field(default=True) + show_next_button: bool = attrs.field(repr=False, default=True) """Should a `Next` button be shown""" - show_last_button: bool = field(default=True) + show_last_button: bool = attrs.field(repr=False, default=True) """Should a `Last` button be shown""" - show_callback_button: bool = field(default=False) + show_callback_button: bool = attrs.field(repr=False, default=False) """Show a button which will call the `callback`""" - show_select_menu: bool = field(default=False) + show_select_menu: bool = attrs.field(repr=False, default=False) """Should a select menu be shown for navigation""" - first_button_emoji: Optional[Union["PartialEmoji", dict, str]] = field( - default="⏎ī¸", metadata=export_converter(process_emoji) + first_button_emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( + repr=False, default="⏎ī¸", metadata=export_converter(process_emoji) ) """The emoji to use for the first button""" - back_button_emoji: Optional[Union["PartialEmoji", dict, str]] = field( - default="âŦ…ī¸", metadata=export_converter(process_emoji) + back_button_emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( + repr=False, default="âŦ…ī¸", metadata=export_converter(process_emoji) ) """The emoji to use for the back button""" - next_button_emoji: Optional[Union["PartialEmoji", dict, str]] = field( - default="➡ī¸", metadata=export_converter(process_emoji) + next_button_emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( + repr=False, default="➡ī¸", metadata=export_converter(process_emoji) ) """The emoji to use for the next button""" - last_button_emoji: Optional[Union["PartialEmoji", dict, str]] = field( - default="⏊", metadata=export_converter(process_emoji) + last_button_emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( + repr=False, default="⏊", metadata=export_converter(process_emoji) ) """The emoji to use for the last button""" - callback_button_emoji: Optional[Union["PartialEmoji", dict, str]] = field( - default="✅", metadata=export_converter(process_emoji) + callback_button_emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( + repr=False, default="✅", metadata=export_converter(process_emoji) ) """The emoji to use for the callback button""" - wrong_user_message: str = field(default="This paginator is not for you") + wrong_user_message: str = attrs.field(repr=False, default="This paginator is not for you") """The message to be sent when the wrong user uses this paginator""" - default_title: Optional[str] = field(default=None) + default_title: Optional[str] = attrs.field(repr=False, default=None) """The default title to show on the embeds""" - default_color: Color = field(default=BrandColors.BLURPLE) + default_color: Color = attrs.field(repr=False, default=BrandColors.BLURPLE) """The default colour to show on the embeds""" - default_button_color: Union[ButtonStyles, int] = field(default=ButtonStyles.BLURPLE) + default_button_color: Union[ButtonStyles, int] = attrs.field(repr=False, default=ButtonStyles.BLURPLE) """The color of the buttons""" - _uuid: str = field(factory=uuid.uuid4) - _message: Message = field(default=MISSING) - _timeout_task: Timeout = field(default=MISSING) - _author_id: Snowflake_Type = field(default=MISSING) + _uuid: str = attrs.field(repr=False, factory=uuid.uuid4) + _message: Message = attrs.field(repr=False, default=MISSING) + _timeout_task: Timeout = attrs.field(repr=False, default=MISSING) + _author_id: Snowflake_Type = attrs.field(repr=False, default=MISSING) def __attrs_post_init__(self) -> None: self.client.add_component_callback( @@ -257,7 +264,7 @@ def create_components(self, disable: bool = False) -> List[ActionRow]: if self.show_select_menu: current = self.pages[self.page_index] output.append( - Select( + StringSelectMenu( [ SelectOption(f"{i+1} {p.get_summary if isinstance(p, Page) else p.title}", str(i)) for i, p in enumerate(self.pages) diff --git a/naff/ext/prefixed_help.py b/naff/ext/prefixed_help.py index cba413ab8..8dcd80bfa 100644 --- a/naff/ext/prefixed_help.py +++ b/naff/ext/prefixed_help.py @@ -1,9 +1,10 @@ import functools +from logging import Logger from typing import TYPE_CHECKING import attrs -from naff import Embed -from naff.client.const import logger + +from naff import Embed, get_logger from naff.ext.paginators import Paginator from naff.models.discord.color import BrandColors, Color from naff.models.naff.context import PrefixedContext @@ -15,7 +16,7 @@ __all__ = ("PrefixedHelpCommand",) -@attrs.define(slots=True) +@attrs.define(eq=False, order=False, hash=False, slots=True) class PrefixedHelpCommand: """A help command for all prefixed commands in a bot.""" @@ -49,6 +50,7 @@ class PrefixedHelpCommand: """The text to display when a command does not have a brief string defined.""" _cmd: PrefixedCommand = attrs.field(init=False, default=None) + logger: Logger = attrs.field(init=False, factory=get_logger) def __attrs_post_init__(self) -> None: if not self._cmd: @@ -62,7 +64,7 @@ def register(self) -> None: # replace existing help command if found if "help" in self.client.prefixed_commands: - logger.warning("Replacing existing help command.") + self.logger.warning("Replacing existing help command.") del self.client.prefixed_commands["help"] self.client.add_prefixed_command(self._cmd) # type: ignore diff --git a/naff/ext/sentry.py b/naff/ext/sentry.py index 0d54a6a9e..e2b7a5b8b 100644 --- a/naff/ext/sentry.py +++ b/naff/ext/sentry.py @@ -9,13 +9,15 @@ from typing import Any, Callable, Optional from naff.api.events.internal import Error -from naff.client.const import logger +from naff.client.const import get_logger from naff.models.naff.tasks.task import Task try: import sentry_sdk except ModuleNotFoundError: - logger.error("sentry-sdk not installed, cannot enable sentry integration. Install with `pip install naff[sentry]`") + get_logger().error( + "sentry-sdk not installed, cannot enable sentry integration. Install with `pip install naff[sentry]`" + ) raise from naff import Client, Extension, listen @@ -53,7 +55,7 @@ async def on_startup(self) -> None: ) sentry_sdk.set_tag("bot_name", str(self.bot.user)) - @listen() + @listen(disable_default_listeners=False) async def on_error(self, event: Error) -> None: with sentry_sdk.configure_scope() as scope: scope.set_tag("source", event.source) @@ -66,6 +68,8 @@ async def on_error(self, event: Error) -> None: "message": event.ctx.message, }, ) + if event.ctx.author: + scope.set_user({"id": event.ctx.author.id, "username": event.ctx.author.tag}) sentry_sdk.capture_exception(event.error) @@ -89,7 +93,7 @@ def setup( filter: Optional[Callable[[dict[str, Any], dict[str, Any]], Optional[dict[str, Any]]]] = None, ) -> None: if not token: - logger.error("Cannot enable sentry integration, no token provided") + bot.logger.error("Cannot enable sentry integration, no token provided") return if filter is None: filter = default_sentry_filter diff --git a/naff/models/discord/__init__.py b/naff/models/discord/__init__.py index 7676ad1f8..a7bcbaa42 100644 --- a/naff/models/discord/__init__.py +++ b/naff/models/discord/__init__.py @@ -1,6 +1,8 @@ from .activity import * +from .app_perms import * from .application import * from .asset import * +from .auto_mod import * from .channel import * from .color import * from .components import * @@ -24,4 +26,3 @@ from .user import * from .voice_state import * from .webhooks import * -from .auto_mod import * diff --git a/naff/models/discord/activity.py b/naff/models/discord/activity.py index f9b0f8b6d..fade60376 100644 --- a/naff/models/discord/activity.py +++ b/naff/models/discord/activity.py @@ -1,7 +1,8 @@ from typing import Optional, List +import attrs + from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import timestamp_converter, optional from naff.client.utils.serializer import dict_filter_none from naff.models.discord.emoji import PartialEmoji @@ -18,77 +19,83 @@ ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ActivityTimestamps(DictSerializationMixin): - start: Optional[Timestamp] = field(default=None, converter=optional(timestamp_converter)) + start: Optional[Timestamp] = attrs.field(repr=False, default=None, converter=optional(timestamp_converter)) """The start time of the activity. Shows "elapsed" timer on discord client.""" - end: Optional[Timestamp] = field(default=None, converter=optional(timestamp_converter)) + end: Optional[Timestamp] = attrs.field(repr=False, default=None, converter=optional(timestamp_converter)) """The end time of the activity. Shows "remaining" timer on discord client.""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ActivityParty(DictSerializationMixin): - id: Optional[str] = field(default=None) + id: Optional[str] = attrs.field(repr=False, default=None) """A unique identifier for this party""" - size: Optional[List[int]] = field(default=None) + size: Optional[List[int]] = attrs.field(repr=False, default=None) """Info about the size of the party""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ActivityAssets(DictSerializationMixin): - large_image: Optional[str] = field(default=None) + large_image: Optional[str] = attrs.field(repr=False, default=None) """The large image for this activity. Uses discord's asset image url format.""" - large_text: Optional[str] = field(default=None) + large_text: Optional[str] = attrs.field(repr=False, default=None) """Hover text for the large image""" - small_image: Optional[str] = field(default=None) + small_image: Optional[str] = attrs.field(repr=False, default=None) """The large image for this activity. Uses discord's asset image url format.""" - small_text: Optional[str] = field(default=None) + small_text: Optional[str] = attrs.field(repr=False, default=None) """Hover text for the small image""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ActivitySecrets(DictSerializationMixin): - join: Optional[str] = field(default=None) + join: Optional[str] = attrs.field(repr=False, default=None) """The secret for joining a party""" - spectate: Optional[str] = field(default=None) + spectate: Optional[str] = attrs.field(repr=False, default=None) """The secret for spectating a party""" - match: Optional[str] = field(default=None) + match: Optional[str] = attrs.field(repr=False, default=None) """The secret for a specific instanced match""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Activity(DictSerializationMixin): """Represents a discord activity object use for rich presence in discord.""" - name: str = field(repr=True) + name: str = attrs.field(repr=True) """The activity's name""" - type: ActivityType = field(repr=True, default=ActivityType.GAME) + type: ActivityType = attrs.field(repr=True, default=ActivityType.GAME) """The type of activity""" - url: Optional[str] = field(repr=True, default=None) + url: Optional[str] = attrs.field(repr=True, default=None) """Stream url, is validated when type is 1""" - created_at: Optional[Timestamp] = field(repr=True, default=None, converter=optional(timestamp_converter)) + created_at: Optional[Timestamp] = attrs.field(repr=True, default=None, converter=optional(timestamp_converter)) """When the activity was added to the user's session""" - timestamps: Optional[ActivityTimestamps] = field(default=None, converter=optional(ActivityTimestamps.from_dict)) + timestamps: Optional[ActivityTimestamps] = attrs.field( + repr=False, default=None, converter=optional(ActivityTimestamps.from_dict) + ) """Start and/or end of the game""" - application_id: "Snowflake_Type" = field(default=None) + application_id: "Snowflake_Type" = attrs.field(repr=False, default=None) """Application id for the game""" - details: Optional[str] = field(default=None) + details: Optional[str] = attrs.field(repr=False, default=None) """What the player is currently doing""" - state: Optional[str] = field(default=None) + state: Optional[str] = attrs.field(repr=False, default=None) """The user's current party status""" - emoji: Optional[PartialEmoji] = field(default=None, converter=optional(PartialEmoji.from_dict)) + emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict)) """The emoji used for a custom status""" - party: Optional[ActivityParty] = field(default=None, converter=optional(ActivityParty.from_dict)) + party: Optional[ActivityParty] = attrs.field(repr=False, default=None, converter=optional(ActivityParty.from_dict)) """Information for the current party of the player""" - assets: Optional[ActivityAssets] = field(default=None, converter=optional(ActivityAssets.from_dict)) + assets: Optional[ActivityAssets] = attrs.field( + repr=False, default=None, converter=optional(ActivityAssets.from_dict) + ) """Assets to display on the player's profile""" - secrets: Optional[ActivitySecrets] = field(default=None, converter=optional(ActivitySecrets.from_dict)) + secrets: Optional[ActivitySecrets] = attrs.field( + repr=False, default=None, converter=optional(ActivitySecrets.from_dict) + ) """Secrets for Rich Presence joining and spectating""" - instance: Optional[bool] = field(default=False) + instance: Optional[bool] = attrs.field(repr=False, default=False) """Whether or not the activity is an instanced game session""" - flags: Optional[ActivityFlags] = field(default=None, converter=optional(ActivityFlags)) + flags: Optional[ActivityFlags] = attrs.field(repr=False, default=None, converter=optional(ActivityFlags)) """Activity flags bitwise OR together, describes what the payload includes""" - buttons: List[str] = field(factory=list) + buttons: List[str] = attrs.field(repr=False, factory=list) """The custom buttons shown in the Rich Presence (max 2)""" @classmethod diff --git a/naff/models/discord/app_perms.py b/naff/models/discord/app_perms.py new file mode 100644 index 000000000..ee2f9992c --- /dev/null +++ b/naff/models/discord/app_perms.py @@ -0,0 +1,85 @@ +from typing import TYPE_CHECKING + +import attrs + +from naff.models.discord.base import DiscordObject, ClientObject +from naff.models.discord.enums import InteractionPermissionTypes +from naff.models.discord.snowflake import to_snowflake + +if TYPE_CHECKING: + from naff import Snowflake_Type, Guild + +__all__ = ("ApplicationCommandPermission",) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class ApplicationCommandPermission(DiscordObject): + id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) + """ID of the role user or channel""" + type: InteractionPermissionTypes = attrs.field(repr=False, converter=InteractionPermissionTypes) + """Type of permission (role user or channel)""" + permission: bool = attrs.field(repr=False, default=False) + """Whether the command is enabled for this permission""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class CommandPermissions(ClientObject): + command_id: "Snowflake_Type" = attrs.field( + repr=False, + ) + _guild: "Guild" = attrs.field( + repr=False, + ) + + permissions: dict["Snowflake_Type", ApplicationCommandPermission] = attrs.field( + repr=False, factory=dict, init=False + ) + + def is_enabled(self, *object_id) -> bool: + """ + Check if a command is enabled for given scope(s). Takes into account the permissions for the bot itself + + Args: + *object_id: The object(s) ID to check for. + + Returns: + Whether the command is enabled for the given scope(s). + """ + bot_perms = self._guild.command_permissions.get(self._client.app.id) + + for obj_id in object_id: + obj_id = to_snowflake(obj_id) + if permission := self.permissions.get(obj_id): + if not permission.permission: + return False + + if bot_perms: + if permission := bot_perms.permissions.get(obj_id): + if not permission.permission: + return False + return True + + def is_enabled_in_context(self, context) -> bool: + """ + Check if a command is enabled for the given context. + + Args: + context: The context to check for. + + Returns: + Whether the command is enabled for the given context. + """ + everyone_role = context.guild.id + all_channels = context.guild.id - 1 # why tf discord + return self.is_enabled( + context.channel.id, *context.author.roles, context.author.id, everyone_role, all_channels + ) + + def update_permissions(self, *permissions: ApplicationCommandPermission) -> None: + """ + Update the permissions for the command. + + Args: + permissions: The permission to set. + """ + self.permissions = {perm.id: perm for perm in permissions} diff --git a/naff/models/discord/application.py b/naff/models/discord/application.py index 6f589094c..2a6265e06 100644 --- a/naff/models/discord/application.py +++ b/naff/models/discord/application.py @@ -1,7 +1,8 @@ from typing import TYPE_CHECKING, List, Optional, Dict, Any +import attrs + from naff.client.const import MISSING -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional from naff.models.discord.asset import Asset from naff.models.discord.enums import ApplicationFlags @@ -16,50 +17,52 @@ __all__ = ("Application",) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Application(DiscordObject): """Represents a discord application.""" - name: str = field(repr=True) + name: str = attrs.field(repr=True) """The name of the application""" - icon: Optional[Asset] = field(default=None) + icon: Optional[Asset] = attrs.field(repr=False, default=None) """The icon of the application""" - description: Optional[str] = field(default=None) + description: Optional[str] = attrs.field(repr=False, default=None) """The description of the application""" - rpc_origins: Optional[List[str]] = field(default=None) + rpc_origins: Optional[List[str]] = attrs.field(repr=False, default=None) """An array of rpc origin urls, if rpc is enabled""" - bot_public: bool = field(default=True) + bot_public: bool = attrs.field(repr=False, default=True) """When false only app owner can join the app's bot to guilds""" - bot_require_code_grant: bool = field(default=False) + bot_require_code_grant: bool = attrs.field(repr=False, default=False) """When true the app's bot will only join upon completion of the full oauth2 code grant flow""" - terms_of_service_url: Optional[str] = field(default=None) + terms_of_service_url: Optional[str] = attrs.field(repr=False, default=None) """The url of the app's terms of service""" - privacy_policy_url: Optional[str] = field(default=None) + privacy_policy_url: Optional[str] = attrs.field(repr=False, default=None) """The url of the app's privacy policy""" - owner_id: Optional[Snowflake_Type] = field(default=None, converter=optional(to_snowflake)) + owner_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) """The id of the owner of the application""" - summary: str = field() + summary: str = attrs.field( + repr=False, + ) """If this application is a game sold on Discord, this field will be the summary field for the store page of its primary sku""" - verify_key: Optional[str] = field(default=MISSING) + verify_key: Optional[str] = attrs.field(repr=False, default=MISSING) """The hex encoded key for verification in interactions and the GameSDK's GetTicket""" - team: Optional["Team"] = field(default=None) + team: Optional["Team"] = attrs.field(repr=False, default=None) """If the application belongs to a team, this will be a list of the members of that team""" - guild_id: Optional["Snowflake_Type"] = field(default=None) + guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """If this application is a game sold on Discord, this field will be the guild to which it has been linked""" - primary_sku_id: Optional["Snowflake_Type"] = field(default=None) + primary_sku_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """If this application is a game sold on Discord, this field will be the id of the "Game SKU" that is created, if exists""" - slug: Optional[str] = field(default=None) + slug: Optional[str] = attrs.field(repr=False, default=None) """If this application is a game sold on Discord, this field will be the URL slug that links to the store page""" - cover_image: Optional[Asset] = field(default=None) + cover_image: Optional[Asset] = attrs.field(repr=False, default=None) """The application's default rich presence invite cover""" - flags: Optional["ApplicationFlags"] = field(default=None, converter=optional(ApplicationFlags)) + flags: Optional["ApplicationFlags"] = attrs.field(repr=False, default=None, converter=optional(ApplicationFlags)) """The application's public flags""" - tags: Optional[List[str]] = field(default=None) + tags: Optional[List[str]] = attrs.field(repr=False, default=None) """The application's tags describing its functionality and content""" # todo: implement an ApplicationInstallParams object. See https://discord.com/developers/docs/resources/application#install-params-object - install_params: Optional[dict] = field(default=None) + install_params: Optional[dict] = attrs.field(repr=False, default=None) """The application's settings for in-app invitation to guilds""" - custom_install_url: Optional[str] = field(default=None) + custom_install_url: Optional[str] = attrs.field(repr=False, default=None) """The application's custom authorization link for invitation to a guild""" @classmethod diff --git a/naff/models/discord/asset.py b/naff/models/discord/asset.py index b6cc5f8e2..9b1f5d2e5 100644 --- a/naff/models/discord/asset.py +++ b/naff/models/discord/asset.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Optional, Union -from naff.client.utils.attr_utils import define, field +import attrs + from naff.client.utils.serializer import no_export_meta if TYPE_CHECKING: @@ -11,7 +12,7 @@ __all__ = ("Asset",) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Asset: """ Represents a discord asset. @@ -25,9 +26,9 @@ class Asset: BASE = "https://cdn.discordapp.com" - _client: "Client" = field(metadata=no_export_meta) - _url: str = field(repr=True) - hash: Optional[str] = field(repr=True, default=None) + _client: "Client" = attrs.field(repr=False, metadata=no_export_meta) + _url: str = attrs.field(repr=True) + hash: Optional[str] = attrs.field(repr=True, default=None) @classmethod def from_path_hash(cls, client: "Client", path: str, asset_hash: str) -> "Asset": diff --git a/naff/models/discord/auto_mod.py b/naff/models/discord/auto_mod.py index c32fe0bae..898d13285 100644 --- a/naff/models/discord/auto_mod.py +++ b/naff/models/discord/auto_mod.py @@ -2,10 +2,10 @@ import attrs -from naff.client.const import logger, MISSING, Absent +from naff.client.const import get_logger, MISSING, Absent from naff.client.mixins.serialization import DictSerializationMixin from naff.client.utils import list_converter, optional -from naff.client.utils.attr_utils import define, field, docs +from naff.client.utils.attr_utils import docs from naff.models.discord.base import ClientObject, DiscordObject from naff.models.discord.enums import AutoModTriggerType, AutoModAction, AutoModEvent, AutoModLanuguageType from naff.models.discord.snowflake import to_snowflake_list, to_snowflake @@ -23,7 +23,7 @@ __all__ = ("AutoModerationAction", "AutoModRule") -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseAction(DictSerializationMixin): """A base implementation of a moderation action @@ -31,13 +31,13 @@ class BaseAction(DictSerializationMixin): type: The type of action that was taken """ - type: AutoModAction = field(converter=AutoModAction) + type: AutoModAction = attrs.field(repr=False, converter=AutoModAction) @classmethod def from_dict_factory(cls, data: dict) -> "BaseAction": action_class = ACTION_MAPPING.get(data.get("type")) if not action_class: - logger.error(f"Unknown action type for {data}") + get_logger().error(f"Unknown action type for {data}") action_class = cls return action_class.from_dict({"type": data.get("type")} | data["metadata"]) @@ -48,7 +48,7 @@ def as_dict(self) -> dict: return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseTrigger(DictSerializationMixin): """A base implementation of an auto-mod trigger @@ -56,7 +56,9 @@ class BaseTrigger(DictSerializationMixin): type: The type of event this trigger is for """ - type: AutoModTriggerType = field(converter=AutoModTriggerType, repr=True, metadata=docs("The type of trigger")) + type: AutoModTriggerType = attrs.field( + converter=AutoModTriggerType, repr=True, metadata=docs("The type of trigger") + ) @classmethod def _process_dict(cls, data: dict[str, Any]) -> dict[str, Any]: @@ -73,7 +75,7 @@ def from_dict_factory(cls, data: dict) -> "BaseAction": trigger_class = TRIGGER_MAPPING.get(data.get("trigger_type")) meta = data.get("trigger_metadata", {}) if not trigger_class: - logger.error(f"Unknown trigger type for {data}") + get_logger().error(f"Unknown trigger type for {data}") trigger_class = cls payload = {"type": data.get("trigger_type"), "trigger_metadata": meta} @@ -93,26 +95,26 @@ def _keyword_converter(filter: str | list[str]) -> list[str]: return [filter] -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class KeywordTrigger(BaseTrigger): """A trigger that checks if content contains words from a user defined list of keywords""" - type: AutoModTriggerType = field( + type: AutoModTriggerType = attrs.field( default=AutoModTriggerType.KEYWORD, converter=AutoModTriggerType, repr=True, metadata=docs("The type of trigger"), ) - keyword_filter: str | list[str] = field( + keyword_filter: str | list[str] = attrs.field( factory=list, repr=True, metadata=docs("What words will trigger this"), converter=_keyword_converter ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class HarmfulLinkFilter(BaseTrigger): """A trigger that checks if content contains any harmful links""" - type: AutoModTriggerType = field( + type: AutoModTriggerType = attrs.field( default=AutoModTriggerType.HARMFUL_LINK, converter=AutoModTriggerType, repr=True, @@ -121,17 +123,17 @@ class HarmfulLinkFilter(BaseTrigger): ... -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class KeywordPresetTrigger(BaseTrigger): """A trigger that checks if content contains words from internal pre-defined wordsets""" - type: AutoModTriggerType = field( + type: AutoModTriggerType = attrs.field( default=AutoModTriggerType.KEYWORD_PRESET, converter=AutoModTriggerType, repr=True, metadata=docs("The type of trigger"), ) - keyword_lists: list[AutoModLanuguageType] = field( + keyword_lists: list[AutoModLanuguageType] = attrs.field( factory=list, converter=list_converter(AutoModLanuguageType), repr=True, @@ -139,62 +141,70 @@ class KeywordPresetTrigger(BaseTrigger): ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class MentionSpamTrigger(BaseTrigger): """A trigger that checks if content contains more mentions than allowed""" - mention_total_limit: int = field(default=3, repr=True, metadata=docs("The maximum number of mentions allowed")) + mention_total_limit: int = attrs.field( + default=3, repr=True, metadata=docs("The maximum number of mentions allowed") + ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BlockMessage(BaseAction): """blocks the content of a message according to the rule""" - type: AutoModAction = field(default=AutoModAction.BLOCK_MESSAGE, converter=AutoModAction) + type: AutoModAction = attrs.field(repr=False, default=AutoModAction.BLOCK_MESSAGE, converter=AutoModAction) ... -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AlertMessage(BaseAction): """logs user content to a specified channel""" - channel_id: "Snowflake_Type" = field(repr=True) - type: AutoModAction = field(default=AutoModAction.ALERT_MESSAGE, converter=AutoModAction) + channel_id: "Snowflake_Type" = attrs.field(repr=True) + type: AutoModAction = attrs.field(repr=False, default=AutoModAction.ALERT_MESSAGE, converter=AutoModAction) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class TimeoutUser(BaseAction): """timeout user for a specified duration""" - duration_seconds: int = field(repr=True, default=60) - type: AutoModAction = field(default=AutoModAction.TIMEOUT_USER, converter=AutoModAction) + duration_seconds: int = attrs.field(repr=True, default=60) + type: AutoModAction = attrs.field(repr=False, default=AutoModAction.TIMEOUT_USER, converter=AutoModAction) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AutoModRule(DiscordObject): """A representation of an auto mod rule""" - name: str = field() + name: str = attrs.field( + repr=False, + ) """The name of the rule""" - enabled: bool = field(default=False) + enabled: bool = attrs.field(repr=False, default=False) """whether the rule is enabled""" - actions: list[BaseAction] = field(factory=list) + actions: list[BaseAction] = attrs.field(repr=False, factory=list) """the actions which will execute when the rule is triggered""" - event_type: AutoModEvent = field() + event_type: AutoModEvent = attrs.field( + repr=False, + ) """the rule event type""" - trigger: BaseTrigger = field() + trigger: BaseTrigger = attrs.field( + repr=False, + ) """The trigger for this rule""" - exempt_roles: list["Snowflake_Type"] = field(factory=list, converter=to_snowflake_list) + exempt_roles: list["Snowflake_Type"] = attrs.field(repr=False, factory=list, converter=to_snowflake_list) """the role ids that should not be affected by the rule (Maximum of 20)""" - exempt_channels: list["Snowflake_Type"] = field(factory=list, converter=to_snowflake_list) + exempt_channels: list["Snowflake_Type"] = attrs.field(repr=False, factory=list, converter=to_snowflake_list) """the channel ids that should not be affected by the rule (Maximum of 50)""" - _guild_id: "Snowflake_Type" = field(default=MISSING) + _guild_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING) """the guild which this rule belongs to""" - _creator_id: "Snowflake_Type" = field(default=MISSING) + _creator_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING) """the user which first created this rule""" - id: "Snowflake_Type" = field(default=MISSING, converter=optional(to_snowflake)) + id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING, converter=optional(to_snowflake)) @classmethod def _process_dict(cls, data: dict, client: "Client") -> dict: @@ -282,21 +292,25 @@ async def modify( return AutoModRule.from_dict(out, self._client) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AutoModerationAction(ClientObject): - rule_trigger_type: AutoModTriggerType = field(converter=AutoModTriggerType) - rule_id: "Snowflake_Type" = field() + rule_trigger_type: AutoModTriggerType = attrs.field(repr=False, converter=AutoModTriggerType) + rule_id: "Snowflake_Type" = attrs.field( + repr=False, + ) - action: BaseAction = field(default=MISSING, repr=True) + action: BaseAction = attrs.field(default=MISSING, repr=True) - matched_keyword: str = field(repr=True) - matched_content: Optional[str] = field(default=None) - content: Optional[str] = field(default=None) + matched_keyword: str = attrs.field(repr=True) + matched_content: Optional[str] = attrs.field(repr=False, default=None) + content: Optional[str] = attrs.field(repr=False, default=None) - _message_id: Optional["Snowflake_Type"] = field(default=None) - _alert_system_message_id: Optional["Snowflake_Type"] = field(default=None) - _channel_id: Optional["Snowflake_Type"] = field(default=None) - _guild_id: "Snowflake_Type" = field() + _message_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) + _alert_system_message_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) + _channel_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) + _guild_id: "Snowflake_Type" = attrs.field( + repr=False, + ) @classmethod def _process_dict(cls, data: dict, client: "Client") -> dict: diff --git a/naff/models/discord/base.py b/naff/models/discord/base.py index f1d0759d4..084bd4b1e 100644 --- a/naff/models/discord/base.py +++ b/naff/models/discord/base.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Type +import attrs + from naff.client.const import T from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.serializer import no_export_meta from naff.models.discord.snowflake import SnowflakeObject @@ -12,11 +13,11 @@ __all__ = ("ClientObject", "DiscordObject") -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class ClientObject(DictSerializationMixin): """Serializable object that requires client reference.""" - _client: "Client" = field(metadata=no_export_meta) + _client: "Client" = attrs.field(repr=False, metadata=no_export_meta) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -34,12 +35,11 @@ def from_list(cls: Type[T], datas: List[Dict[str, Any]], client: "Client") -> Li def update_from_dict(self, data) -> T: data = self._process_dict(data, self._client) for key, value in self._filter_kwargs(data, self._get_keys()).items(): - # todo improve setattr(self, key, value) return self -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class DiscordObject(SnowflakeObject, ClientObject): pass diff --git a/naff/models/discord/channel.py b/naff/models/discord/channel.py index 3e60220d4..70dfa4583 100644 --- a/naff/models/discord/channel.py +++ b/naff/models/discord/channel.py @@ -1,26 +1,26 @@ import time +from asyncio import QueueEmpty from collections import namedtuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable import attrs import naff.models as models - -from naff.client.const import MISSING, DISCORD_EPOCH, Absent, logger +from naff.client.const import Absent, DISCORD_EPOCH, MISSING from naff.client.errors import NotFound, VoiceNotConnected, TooManyChanges from naff.client.mixins.send import SendMixin from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter from naff.client.utils.misc_utils import get from naff.client.utils.serializer import to_dict, to_image_data from naff.models.discord.base import DiscordObject +from naff.models.discord.emoji import PartialEmoji from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.discord.snowflake import Snowflake_Type, to_snowflake, to_optional_snowflake, SnowflakeObject -from naff.models.misc.iterator import AsyncIterator from naff.models.discord.thread import ThreadTag -from naff.models.discord.emoji import PartialEmoji +from naff.models.misc.context_manager import Typing +from naff.models.misc.iterator import AsyncIterator from .enums import ( ChannelFlags, ChannelTypes, @@ -32,7 +32,6 @@ MessageFlags, InviteTargetTypes, ) -from naff.models.misc.context_manager import Typing if TYPE_CHECKING: from aiohttp import FormData @@ -128,7 +127,34 @@ async def fetch(self) -> List["models.Message"]: return messages -@define() +class ArchivedForumPosts(AsyncIterator): + def __init__(self, channel: "BaseChannel", limit: int = 50, before: Snowflake_Type = None) -> None: + self.channel: "BaseChannel" = channel + self.before: Snowflake_Type = before + self._more: bool = True + super().__init__(limit) + + if self.before: + self.last = self.before + + async def fetch(self) -> list["GuildForumPost"]: + if self._more: + expected = self.get_limit + + rcv = await self.channel._client.http.list_public_archived_threads( + self.channel.id, limit=expected, before=to_snowflake(self.last) if self.last else None + ) + threads = [self.channel._client.cache.place_channel_data(data) for data in rcv["threads"]] + + if not rcv: + raise QueueEmpty + + self._more = rcv.get("has_more", False) + return threads + raise QueueEmpty + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class PermissionOverwrite(SnowflakeObject, DictSerializationMixin): """ Channel Permissions Overwrite object. @@ -138,11 +164,15 @@ class PermissionOverwrite(SnowflakeObject, DictSerializationMixin): """ - type: "OverwriteTypes" = field(repr=True, converter=OverwriteTypes) + type: "OverwriteTypes" = attrs.field(repr=True, converter=OverwriteTypes) """Permission overwrite type (role or member)""" - allow: Optional["Permissions"] = field(repr=True, converter=optional_c(Permissions), kw_only=True, default=None) + allow: Optional["Permissions"] = attrs.field( + repr=True, converter=optional_c(Permissions), kw_only=True, default=None + ) """Permissions to allow""" - deny: Optional["Permissions"] = field(repr=True, converter=optional_c(Permissions), kw_only=True, default=None) + deny: Optional["Permissions"] = attrs.field( + repr=True, converter=optional_c(Permissions), kw_only=True, default=None + ) """Permissions to deny""" @classmethod @@ -191,15 +221,17 @@ def add_denies(self, *perms: "Permissions") -> None: self.deny |= perm -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class MessageableMixin(SendMixin): - last_message_id: Optional[Snowflake_Type] = field( - default=None + last_message_id: Optional[Snowflake_Type] = attrs.field( + repr=False, default=None ) # TODO May need to think of dynamically updating this. """The id of the last message sent in this channel (may not point to an existing or valid message)""" - default_auto_archive_duration: int = field(default=AutoArchiveDuration.ONE_DAY) + default_auto_archive_duration: int = attrs.field(repr=False, default=AutoArchiveDuration.ONE_DAY) """Default duration that the clients (not the API) will use for newly created threads, in minutes, to automatically archive the thread after recent activity""" - last_pin_timestamp: Optional["models.Timestamp"] = field(default=None, converter=optional_c(timestamp_converter)) + last_pin_timestamp: Optional["models.Timestamp"] = attrs.field( + repr=False, default=None, converter=optional_c(timestamp_converter) + ) """When the last pinned message was pinned. This may be None when a message is not pinned.""" async def _send_http_request( @@ -434,7 +466,7 @@ def typing(self) -> Typing: return Typing(self) -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class InvitableMixin: async def create_invite( self, @@ -480,7 +512,14 @@ async def create_invite( target_type = InviteTargetTypes.EMBEDDED_APPLICATION invite_data = await self._client.http.create_channel_invite( - self.id, max_age, max_uses, temporary, unique, target_type, target_user, target_application, reason + self.id, + max_age, + max_uses, + temporary, + unique, + target_user_id=target_user, + target_application_id=target_application, + reason=reason, ) return models.Invite.from_dict(invite_data, self._client) @@ -496,7 +535,7 @@ async def fetch_invites(self) -> List["models.Invite"]: return models.Invite.from_list(invites_data, self._client) -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class ThreadableMixin: async def create_thread( self, @@ -672,7 +711,7 @@ async def fetch_all_threads(self) -> "models.ThreadList": return threads -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class WebhookMixin: async def create_webhook(self, name: str, avatar: Absent[UPLOADABLE_TYPE] = MISSING) -> "models.Webhook": """ @@ -713,11 +752,11 @@ async def fetch_webhooks(self) -> List["models.Webhook"]: return [models.Webhook.from_dict(d, self._client) for d in resp] -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class BaseChannel(DiscordObject): - name: Optional[str] = field(repr=True, default=None) + name: Optional[str] = attrs.field(repr=True, default=None) """The name of the channel (1-100 characters)""" - type: Union[ChannelTypes, int] = field(repr=True, converter=ChannelTypes) + type: Union[ChannelTypes, int] = attrs.field(repr=True, converter=ChannelTypes) """The channel topic (0-1024 characters)""" @classmethod @@ -736,9 +775,15 @@ def from_dict_factory(cls, data: dict, client: "Client") -> "TYPE_ALL_CHANNEL": channel_type = data.get("type", None) channel_class = TYPE_CHANNEL_MAPPING.get(channel_type, None) if not channel_class: - logger.error(f"Unsupported channel type for {data} ({channel_type}).") + client.logger.error(f"Unsupported channel type for {data} ({channel_type}).") channel_class = BaseChannel + if channel_class == GuildPublicThread: + # attempt to determine if this thread is a forum post (thanks discord) + parent_channel = client.cache.get_channel(data["parent_id"]) + if parent_channel and parent_channel.type == ChannelTypes.GUILD_FORUM: + channel_class = GuildForumPost + return channel_class.from_dict(data, client) @property @@ -748,6 +793,7 @@ def mention(self) -> str: async def edit( self, + *, name: Absent[str] = MISSING, icon: Absent[UPLOADABLE_TYPE] = MISSING, type: Absent[ChannelTypes] = MISSING, @@ -843,9 +889,9 @@ async def delete(self, reason: Absent[Optional[str]] = MISSING) -> None: # DMs -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class DMChannel(BaseChannel, MessageableMixin): - recipients: List["models.User"] = field(factory=list) + recipients: List["models.User"] = attrs.field(repr=False, factory=list) """The users of the DM that will receive messages sent""" @classmethod @@ -864,7 +910,7 @@ def members(self) -> List["models.User"]: return self.recipients -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class DM(DMChannel): @property def recipient(self) -> "models.User": @@ -872,15 +918,16 @@ def recipient(self) -> "models.User": return self.recipients[0] -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class DMGroup(DMChannel): - owner_id: Snowflake_Type = field(repr=True) + owner_id: Snowflake_Type = attrs.field(repr=True) """id of the creator of the group DM""" - application_id: Optional[Snowflake_Type] = field(default=None) + application_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """Application id of the group DM creator if it is bot-created""" async def edit( self, + *, name: Absent[str] = MISSING, icon: Absent[UPLOADABLE_TYPE] = MISSING, reason: Absent[str] = MISSING, @@ -937,18 +984,18 @@ async def remove_recipient(self, user: Union["models.User", Snowflake_Type]) -> # Guild -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class GuildChannel(BaseChannel): - position: Optional[int] = field(default=0) + position: Optional[int] = attrs.field(repr=False, default=0) """Sorting position of the channel""" - nsfw: bool = field(default=False) + nsfw: bool = attrs.field(repr=False, default=False) """Whether the channel is nsfw""" - parent_id: Optional[Snowflake_Type] = field(default=None, converter=optional_c(to_snowflake)) + parent_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the parent category for a channel (each parent category can contain up to 50 channels)""" - permission_overwrites: list[PermissionOverwrite] = field(factory=list) + permission_overwrites: list[PermissionOverwrite] = attrs.field(repr=False, factory=list) """A list of the overwritten permissions for the members and roles""" - _guild_id: Optional[Snowflake_Type] = field(default=None, converter=optional_c(to_snowflake)) + _guild_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) @property def guild(self) -> "models.Guild": @@ -1082,7 +1129,12 @@ async def edit_permission(self, overwrite: PermissionOverwrite, reason: Optional reason: The reason for this change """ await self._client.http.edit_channel_permission( - self.id, overwrite.id, overwrite.allow, overwrite.deny, overwrite.type, reason + self.id, + overwrite_id=overwrite.id, + allow=overwrite.allow, + deny=overwrite.deny, + perm_type=overwrite.type, + reason=reason, ) async def delete_permission( @@ -1255,7 +1307,7 @@ async def clone(self, name: Optional[str] = None, reason: Absent[Optional[str]] ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildCategory(GuildChannel): @property def channels(self) -> List["TYPE_GUILD_CHANNEL"]: @@ -1288,6 +1340,7 @@ def news_channels(self) -> List["GuildNews"]: async def edit( self, + *, name: Absent[str] = MISSING, position: Absent[int] = MISSING, permission_overwrites: Absent[ @@ -1522,13 +1575,14 @@ async def create_stage_channel( ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildNews(GuildChannel, MessageableMixin, InvitableMixin, ThreadableMixin, WebhookMixin): - topic: Optional[str] = field(default=None) + topic: Optional[str] = attrs.field(repr=False, default=None) """The channel topic (0-1024 characters)""" async def edit( self, + *, name: Absent[str] = MISSING, position: Absent[int] = MISSING, permission_overwrites: Absent[ @@ -1611,15 +1665,16 @@ async def create_thread_from_message( ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildText(GuildChannel, MessageableMixin, InvitableMixin, ThreadableMixin, WebhookMixin): - topic: Optional[str] = field(default=None) + topic: Optional[str] = attrs.field(repr=False, default=None) """The channel topic (0-1024 characters)""" - rate_limit_per_user: int = field(default=0) + rate_limit_per_user: int = attrs.field(repr=False, default=0) """Amount of seconds a user has to wait before sending another message (0-21600)""" async def edit( self, + *, name: Absent[str] = MISSING, position: Absent[int] = MISSING, permission_overwrites: Absent[ @@ -1751,34 +1806,38 @@ async def create_thread_from_message( # Guild Threads -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class ThreadChannel(BaseChannel, MessageableMixin, WebhookMixin): - parent_id: Snowflake_Type = field(default=None, converter=optional_c(to_snowflake)) + parent_id: Snowflake_Type = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the text channel this thread was created""" - owner_id: Snowflake_Type = field(default=None, converter=optional_c(to_snowflake)) + owner_id: Snowflake_Type = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the creator of the thread""" - topic: Optional[str] = field(default=None) + topic: Optional[str] = attrs.field(repr=False, default=None) """The thread topic (0-1024 characters)""" - message_count: int = field(default=0) + message_count: int = attrs.field(repr=False, default=0) """An approximate count of messages in a thread, stops counting at 50""" - member_count: int = field(default=0) + member_count: int = attrs.field(repr=False, default=0) """An approximate count of users in a thread, stops counting at 50""" - archived: bool = field(default=False) + archived: bool = attrs.field(repr=False, default=False) """Whether the thread is archived""" - auto_archive_duration: int = field( - default=attrs.Factory(lambda self: self.default_auto_archive_duration, takes_self=True) + auto_archive_duration: int = attrs.field( + repr=False, default=attrs.Factory(lambda self: self.default_auto_archive_duration, takes_self=True) ) """Duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080""" - locked: bool = field(default=False) + locked: bool = attrs.field(repr=False, default=False) """Whether the thread is locked""" - archive_timestamp: Optional["models.Timestamp"] = field(default=None, converter=optional_c(timestamp_converter)) + archive_timestamp: Optional["models.Timestamp"] = attrs.field( + repr=False, default=None, converter=optional_c(timestamp_converter) + ) """Timestamp when the thread's archive status was last changed, used for calculating recent activity""" - create_timestamp: Optional["models.Timestamp"] = field(default=None, converter=optional_c(timestamp_converter)) + create_timestamp: Optional["models.Timestamp"] = attrs.field( + repr=False, default=None, converter=optional_c(timestamp_converter) + ) """Timestamp when the thread was created""" - flags: ChannelFlags = field(default=ChannelFlags.NONE, converter=ChannelFlags) + flags: ChannelFlags = attrs.field(repr=False, default=ChannelFlags.NONE, converter=ChannelFlags) """Flags for the thread""" - _guild_id: Snowflake_Type = field(default=None, converter=optional_c(to_snowflake)) + _guild_id: Snowflake_Type = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -1817,6 +1876,25 @@ def permission_overwrites(self) -> List["PermissionOverwrite"]: """The permission overwrites for this channel.""" return [] + def permissions_for(self, instance: Snowflake_Type) -> Permissions: + """ + Calculates permissions for an instance + + Args: + instance: Member or Role instance (or its ID) + + Returns: + Permissions data + + Raises: + ValueError: If could not find any member or role by given ID + RuntimeError: If given instance is from another guild + + """ + if self.parent_channel: + return self.parent_channel.permissions_for(instance) + return Permissions.NONE + async def fetch_members(self) -> List["models.ThreadMember"]: """Get the members that have access to this thread.""" members_data = await self._client.http.list_thread_members(self.id) @@ -1865,10 +1943,11 @@ async def archive(self, locked: bool = False, reason: Absent[str] = MISSING) -> return await super().edit(locked=locked, archived=True, reason=reason) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildNewsThread(ThreadChannel): async def edit( self, + *, name: Absent[str] = MISSING, archived: Absent[bool] = MISSING, auto_archive_duration: Absent[AutoArchiveDuration] = MISSING, @@ -1903,13 +1982,67 @@ async def edit( ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildPublicThread(ThreadChannel): + async def edit( + self, + *, + name: Absent[str] = MISSING, + archived: Absent[bool] = MISSING, + auto_archive_duration: Absent[AutoArchiveDuration] = MISSING, + locked: Absent[bool] = MISSING, + rate_limit_per_user: Absent[int] = MISSING, + flags: Absent[Union[int, ChannelFlags]] = MISSING, + reason: Absent[str] = MISSING, + **kwargs, + ) -> "GuildPublicThread": + """ + Edit this thread. + + Args: + name: 1-100 character channel name + archived: whether the thread is archived + auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080 + locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it + rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600) + flags: channel flags for forum threads + reason: The reason for this change + + Returns: + The edited thread channel object. + """ + return await super().edit( + name=name, + archived=archived, + auto_archive_duration=auto_archive_duration, + locked=locked, + rate_limit_per_user=rate_limit_per_user, + reason=reason, + flags=flags, + **kwargs, + ) - _applied_tags: List[Snowflake_Type] = field(factory=list) + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class GuildForumPost(GuildPublicThread): + """ + A forum post + + !!! note + This model is an abstraction of the api - In reality all posts are GuildPublicThread + """ + + _applied_tags: list[Snowflake_Type] = attrs.field(repr=False, factory=list) + + @classmethod + def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: + data = super()._process_dict(data, client) + data["_applied_tags"] = data.pop("applied_tags") if "applied_tags" in data else [] + return data async def edit( self, + *, name: Absent[str] = MISSING, archived: Absent[bool] = MISSING, auto_archive_duration: Absent[AutoArchiveDuration] = MISSING, @@ -1919,18 +2052,18 @@ async def edit( flags: Absent[Union[int, ChannelFlags]] = MISSING, reason: Absent[str] = MISSING, **kwargs, - ) -> "GuildPublicThread": + ) -> "GuildForumPost": """ Edit this thread. Args: name: 1-100 character channel name archived: whether the thread is archived - applied_tags: list of tags to apply to a forum post (!!! This is for forum threads only) + applied_tags: list of tags to apply auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080 locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600) - flags: channel flags for forum threads + flags: channel flags to apply reason: The reason for this change Returns: @@ -1953,38 +2086,54 @@ async def edit( @property def applied_tags(self) -> list[ThreadTag]: - """ - The tags applied to this thread. - - !!! note - This is only on forum threads. - - """ + """The tags applied to this thread.""" if not isinstance(self.parent_channel, GuildForum): raise AttributeError("This is only available on forum threads.") return [tag for tag in self.parent_channel.available_tags if str(tag.id) in self._applied_tags] @property def initial_post(self) -> Optional["Message"]: + """The initial message posted by the OP.""" + if not isinstance(self.parent_channel, GuildForum): + raise AttributeError("This is only available on forum threads.") + return self.get_message(self.id) + + @property + def pinned(self) -> bool: + """Whether this thread is pinned.""" + return ChannelFlags.PINNED in self.flags + + async def pin(self, reason: Absent[str] = MISSING) -> None: """ - The initial message posted by the OP. + Pin this thread. - !!! note - This is only on forum threads. + Args: + reason: The reason for this pin """ - if not isinstance(self.parent_channel, GuildForum): - raise AttributeError("This is only available on forum threads.") - return self.get_message(self.id) + flags = self.flags | ChannelFlags.PINNED + await self.edit(flags=flags, reason=reason) + + async def unpin(self, reason: Absent[str] = MISSING) -> None: + """ + Unpin this thread. + + Args: + reason: The reason for this unpin + + """ + flags = self.flags & ~ChannelFlags.PINNED + await self.edit(flags=flags, reason=reason) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildPrivateThread(ThreadChannel): - invitable: bool = field(default=False) + invitable: bool = attrs.field(repr=False, default=False) """Whether non-moderators can add other non-moderators to a thread""" async def edit( self, + *, name: Absent[str] = MISSING, archived: Absent[bool] = MISSING, auto_archive_duration: Absent[AutoArchiveDuration] = MISSING, @@ -2026,20 +2175,25 @@ async def edit( # Guild Voices -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False, kw_only=True) class VoiceChannel(GuildChannel): # May not be needed, can be directly just GuildVoice. - bitrate: int = field() + bitrate: int = attrs.field( + repr=False, + ) """The bitrate (in bits) of the voice channel""" - user_limit: int = field() + user_limit: int = attrs.field( + repr=False, + ) """The user limit of the voice channel""" - rtc_region: str = field(default="auto") + rtc_region: str = attrs.field(repr=False, default="auto") """Voice region id for the voice channel, automatic when set to None""" - video_quality_mode: Union[VideoQualityModes, int] = field(default=VideoQualityModes.AUTO) + video_quality_mode: Union[VideoQualityModes, int] = attrs.field(repr=False, default=VideoQualityModes.AUTO) """The camera video quality mode of the voice channel, 1 when not present""" - _voice_member_ids: list[Snowflake_Type] = field(factory=list) + _voice_member_ids: list[Snowflake_Type] = attrs.field(repr=False, factory=list) async def edit( self, + *, name: Absent[str] = MISSING, position: Absent[int] = MISSING, permission_overwrites: Absent[ @@ -2134,14 +2288,14 @@ async def disconnect(self) -> None: raise VoiceNotConnected -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildVoice(VoiceChannel, InvitableMixin, MessageableMixin): pass -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildStageVoice(GuildVoice): - stage_instance: "models.StageInstance" = field(default=MISSING) + stage_instance: "models.StageInstance" = attrs.field(repr=False, default=MISSING) """The stage instance that this voice channel belongs to""" # todo: Listeners and speakers properties (needs voice state caching) @@ -2198,11 +2352,11 @@ async def close_stage(self, reason: Absent[Optional[str]] = MISSING) -> None: await self.stage_instance.delete(reason=reason) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildForum(GuildChannel): - available_tags: List[ThreadTag] = field(factory=list) + available_tags: List[ThreadTag] = attrs.field(repr=False, factory=list) """A list of tags available to assign to threads""" - last_message_id: Optional[Snowflake_Type] = field(default=None) + last_message_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) # TODO: Implement "template" once the API supports them @classmethod @@ -2218,7 +2372,7 @@ async def create_post( self, name: str, content: str | None, - applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag"]]] = MISSING, + applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag", str]]] = MISSING, *, auto_archive_duration: AutoArchiveDuration = AutoArchiveDuration.ONE_DAY, rate_limit_per_user: Absent[int] = MISSING, @@ -2233,7 +2387,7 @@ async def create_post( file: Optional["UPLOADABLE_TYPE"] = None, tts: bool = False, reason: Absent[str] = MISSING, - ) -> "GuildPublicThread": + ) -> "GuildForumPost": """ Create a post within this channel. @@ -2254,10 +2408,23 @@ async def create_post( reason: The reason for creating this post Returns: - A GuildPublicThread object representing the created post. + A GuildForumPost object representing the created post. """ if applied_tags != MISSING: - applied_tags = [str(tag.id) if isinstance(tag, ThreadTag) else str(tag) for tag in applied_tags] + processed = [] + for tag in applied_tags: + if isinstance(tag, ThreadTag): + tag = tag.id + elif isinstance(tag, (str, int)): + tag = self.get_tag(tag, case_insensitive=True) + if not tag: + continue + tag = tag.id + elif isinstance(tag, dict): + tag = tag["id"] + processed.append(tag) + + applied_tags = processed message_payload = models.discord.message.process_message_payload( content=content, @@ -2280,7 +2447,86 @@ async def create_post( ) return self._client.cache.place_channel_data(data) - async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, str]) -> "ThreadTag": + async def fetch_posts(self) -> List["GuildForumPost"]: + """ + Requests all active posts within this channel. + + Returns: + A list of GuildForumPost objects representing the posts. + """ + # I can guarantee this endpoint will need to be converted to an async iterator eventually + data = await self._client.http.list_active_threads(self._guild_id) + threads = [self._client.cache.place_channel_data(post_data) for post_data in data["threads"]] + + return [thread for thread in threads if thread.parent_id == self.id] + + def get_posts(self, *, exclude_archived: bool = True) -> List["GuildForumPost"]: + """ + List all, cached, active posts within this channel. + + Args: + exclude_archived: Whether to exclude archived posts from the response + + Returns: + A list of GuildForumPost objects representing the posts. + """ + out = [thread for thread in self.guild.threads if thread.parent_id == self.id] + if exclude_archived: + return [thread for thread in out if not thread.archived] + return out + + def archived_posts(self, limit: int = 0, before: Snowflake_Type | None = None) -> ArchivedForumPosts: + """An async iterator for all archived posts in this channel.""" + return ArchivedForumPosts(self, limit, before) + + async def fetch_post(self, id: "Snowflake_Type") -> "GuildForumPost": + """ + Fetch a post within this channel. + + Args: + id: The id of the post to fetch + + Returns: + A GuildForumPost object representing the post. + """ + return await self._client.fetch_channel(id) + + def get_post(self, id: "Snowflake_Type") -> "GuildForumPost": + """ + Get a post within this channel. + + Args: + id: The id of the post to get + + Returns: + A GuildForumPost object representing the post. + """ + return self._client.cache.get_channel(id) + + def get_tag(self, value: str | Snowflake_Type, *, case_insensitive: bool = False) -> Optional["ThreadTag"]: + """ + Get a tag within this channel. + + Args: + value: The name or ID of the tag to get + case_insensitive: Whether to ignore case when searching for the tag + + Returns: + A ThreadTag object representing the tag. + """ + + def maybe_insensitive(string: str) -> str: + return string.lower() if case_insensitive else string + + def predicate(tag: ThreadTag) -> Optional["ThreadTag"]: + if str(tag.id) == str(value): + return tag + if maybe_insensitive(tag.name) == maybe_insensitive(value): + return tag + + return next((tag for tag in self.available_tags if predicate(tag)), None) + + async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, str, None] = None) -> "ThreadTag": """ Create a tag for this forum. @@ -2295,15 +2541,20 @@ async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, The created tag object. """ - if isinstance(emoji, str): - emoji = PartialEmoji.from_str(emoji) - elif isinstance(emoji, dict): - emoji = PartialEmoji.from_dict(emoji) + payload = {"channel_id": self.id, "name": name} - if emoji.id: - data = await self._client.http.create_tag(self.id, name, emoji_id=emoji.id) - else: - data = await self._client.http.create_tag(self.id, name, emoji_name=emoji.name) + if emoji: + if isinstance(emoji, str): + emoji = PartialEmoji.from_str(emoji) + elif isinstance(emoji, dict): + emoji = PartialEmoji.from_dict(emoji) + + if emoji.id: + payload["emoji_id"] = emoji.id + else: + payload["emoji_name"] = emoji.name + + data = await self._client.http.create_tag(**payload) channel_data = self._client.cache.place_channel_data(data) return [tag for tag in channel_data.available_tags if tag.name == name][0] @@ -2383,6 +2634,7 @@ def process_permission_overwrites( GuildStageVoice, GuildCategory, GuildPublicThread, + GuildForumPost, GuildPrivateThread, GuildNewsThread, DM, diff --git a/naff/models/discord/color.py b/naff/models/discord/color.py index cb704f881..dd21a895e 100644 --- a/naff/models/discord/color.py +++ b/naff/models/discord/color.py @@ -3,7 +3,7 @@ from enum import Enum from random import randint -from naff.client.utils.attr_utils import define, field +import attrs __all__ = ( "COLOR_TYPES", @@ -26,9 +26,9 @@ hex_regex = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$") -@define(init=False) +@attrs.define(eq=False, order=False, hash=False, init=False) class Color: - value: int = field(repr=True) + value: int = attrs.field(repr=True) """The color value as an integer.""" def __init__(self, color: COLOR_TYPES | None = None) -> None: @@ -42,7 +42,7 @@ def __init__(self, color: COLOR_TYPES | None = None) -> None: if re.match(hex_regex, color): self.hex = color else: - self.value = BrandColors[color].value # todo exception handling for better message + self.value = BrandColors[color].value else: raise TypeError diff --git a/naff/models/discord/components.py b/naff/models/discord/components.py index 1716bcf1e..eb902d300 100644 --- a/naff/models/discord/components.py +++ b/naff/models/discord/components.py @@ -5,10 +5,11 @@ from naff.client.const import SELECTS_MAX_OPTIONS, SELECT_MAX_NAME_LENGTH, ACTION_ROW_MAX_ITEMS, MISSING from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field, str_validator +from naff.client.utils import list_converter +from naff.client.utils.attr_utils import str_validator from naff.client.utils.serializer import export_converter from naff.models.discord.emoji import process_emoji -from naff.models.discord.enums import ButtonStyles, ComponentTypes +from naff.models.discord.enums import ButtonStyles, ComponentTypes, ChannelTypes if TYPE_CHECKING: from naff.models.discord.emoji import PartialEmoji @@ -18,7 +19,11 @@ "InteractiveComponent", "Button", "SelectOption", - "Select", + "StringSelectMenu", + "UserSelectMenu", + "RoleSelectMenu", + "MentionableSelectMenu", + "ChannelSelectMenu", "ActionRow", "process_components", "spread_to_rows", @@ -52,7 +57,7 @@ def from_dict_factory(cls, data: dict) -> "TYPE_ALL_COMPONENT": return component_class.from_dict(data) -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class InteractiveComponent(BaseComponent): """ A base interactive component class. @@ -69,7 +74,7 @@ def __eq__(self, other: Any) -> bool: return False -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Button(InteractiveComponent): """ Represents a discord ui button. @@ -84,15 +89,15 @@ class Button(InteractiveComponent): """ - style: Union[ButtonStyles, int] = field(repr=True) - label: Optional[str] = field(default=None) - emoji: Optional[Union["PartialEmoji", dict, str]] = field( + style: Union[ButtonStyles, int] = attrs.field(repr=True) + label: Optional[str] = attrs.field(repr=False, default=None) + emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( repr=True, default=None, metadata=export_converter(process_emoji) ) - custom_id: Optional[str] = field(repr=True, default=MISSING, validator=str_validator) - url: Optional[str] = field(repr=True, default=None) - disabled: bool = field(repr=True, default=False) - type: Union[ComponentTypes, int] = field( + custom_id: Optional[str] = attrs.field(repr=True, default=MISSING, validator=str_validator) + url: Optional[str] = attrs.field(repr=True, default=None) + disabled: bool = attrs.field(repr=True, default=False) + type: Union[ComponentTypes, int] = attrs.field( repr=True, default=ComponentTypes.BUTTON, init=False, on_setattr=attrs.setters.frozen ) @@ -121,7 +126,7 @@ def _check_object(self) -> None: raise TypeError("You must have at least a label or emoji on a button.") -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class SelectOption(BaseComponent): """ Represents a select option. @@ -135,13 +140,32 @@ class SelectOption(BaseComponent): """ - label: str = field(repr=True, validator=str_validator) - value: str = field(repr=True, validator=str_validator) - description: Optional[str] = field(repr=True, default=None) - emoji: Optional[Union["PartialEmoji", dict, str]] = field( + label: str = attrs.field(repr=True, validator=str_validator) + value: str = attrs.field(repr=True, validator=str_validator) + description: Optional[str] = attrs.field(repr=True, default=None) + emoji: Optional[Union["PartialEmoji", dict, str]] = attrs.field( repr=True, default=None, metadata=export_converter(process_emoji) ) - default: bool = field(repr=True, default=False) + default: bool = attrs.field(repr=True, default=False) + + @classmethod + def converter(cls, value: Any) -> "SelectOption": + if isinstance(value, SelectOption): + return value + if isinstance(value, dict): + return cls.from_dict(value) + + if isinstance(value, str): + return cls(label=value, value=value) + + try: + possible_iter = iter(value) + + return cls(label=possible_iter[0], value=possible_iter[1]) + except TypeError: + pass + + raise TypeError(f"Cannot convert {value} of type {type(value)} to a SelectOption") @label.validator def _label_validator(self, attribute: str, value: str) -> None: @@ -159,34 +183,30 @@ def _description_validator(self, attribute: str, value: str) -> None: raise ValueError("Description length must be 100 or lower.") -@define(kw_only=False) -class Select(InteractiveComponent): +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class BaseSelectMenu(InteractiveComponent): """ - Represents a select component. + Represents a select menu component Attributes: - options List[dict]: The choices in the select, max 25. custom_id str: A developer-defined identifier for the button, max 100 characters. placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) disabled bool: Disable the select and make it not intractable, default false. type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. - """ - options: List[Union[SelectOption, Dict]] = field(repr=True, factory=list) - custom_id: str = field(repr=True, factory=lambda: str(uuid.uuid4()), validator=str_validator) - placeholder: str = field(repr=True, default=None) - min_values: Optional[int] = field(repr=True, default=1) - max_values: Optional[int] = field(repr=True, default=1) - disabled: bool = field(repr=True, default=False) - type: Union[ComponentTypes, int] = field( - repr=True, default=ComponentTypes.SELECT, init=False, on_setattr=attrs.setters.frozen - ) + min_values: int = attrs.field(repr=True, default=1, kw_only=True) + max_values: int = attrs.field(repr=True, default=1, kw_only=True) + placeholder: Optional[str] = attrs.field(repr=True, default=None, kw_only=True) - def __len__(self) -> int: - return len(self.options) + # generic component attributes + disabled: bool = attrs.field(repr=True, default=False, kw_only=True) + custom_id: str = attrs.field(repr=True, factory=lambda: str(uuid.uuid4()), validator=str_validator, kw_only=True) + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.STRING_SELECT, init=False, on_setattr=attrs.setters.frozen + ) @placeholder.validator def _placeholder_validator(self, attribute: str, value: str) -> None: @@ -196,21 +216,49 @@ def _placeholder_validator(self, attribute: str, value: str) -> None: @min_values.validator def _min_values_validator(self, attribute: str, value: int) -> None: if value < 0: - raise ValueError("Select min value cannot be a negative number.") + raise ValueError("StringSelectMenu min value cannot be a negative number.") @max_values.validator def _max_values_validator(self, attribute: str, value: int) -> None: if value < 0: - raise ValueError("Select max value cannot be a negative number.") + raise ValueError("StringSelectMenu max value cannot be a negative number.") + + def _check_object(self) -> None: + super()._check_object() + if not self.custom_id: + raise TypeError("You need to have a custom id to identify the select.") + + if self.max_values < self.min_values: + raise TypeError("Selects max value cannot be less than min value.") + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class StringSelectMenu(BaseSelectMenu): + """ + Represents a string select component. + + Attributes: + options List[dict]: The choices in the select, max 25. + custom_id str: A developer-defined identifier for the button, max 100 characters. + placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. + min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) + max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) + disabled bool: Disable the select and make it not intractable, default false. + type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. + """ + + options: list[SelectOption | str] = attrs.field(repr=True, converter=list_converter(SelectOption.converter)) + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.STRING_SELECT, init=False, on_setattr=attrs.setters.frozen + ) @options.validator def _options_validator(self, attribute: str, value: List[Union[SelectOption, Dict]]) -> None: if not all(isinstance(x, (SelectOption, Dict)) for x in value): - raise ValueError("Select options must be of type `SelectOption`") + raise ValueError("StringSelectMenu options must be of type `SelectOption`") def _check_object(self) -> None: - if not self.custom_id: - raise TypeError("You need to have a custom id to identify the select.") + super()._check_object() if not self.options: raise TypeError("Selects needs to have at least 1 option.") @@ -218,34 +266,109 @@ def _check_object(self) -> None: if len(self.options) > SELECTS_MAX_OPTIONS: raise TypeError("Selects can only hold 25 options") - if self.max_values < self.min_values: - raise TypeError("Selects max value cannot be less than min value.") - - def add_option(self, option: SelectOption) -> None: - if not isinstance(option, (SelectOption, Dict)): - raise ValueError(f"Select option must be of `SelectOption` type, not {type(option)}") + def add_option(self, option: str | SelectOption) -> None: + option = SelectOption.converter(option) self.options.append(option) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class UserSelectMenu(BaseSelectMenu): + """ + Represents a user select component. + + Attributes: + custom_id str: A developer-defined identifier for the button, max 100 characters. + placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. + min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) + max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) + disabled bool: Disable the select and make it not intractable, default false. + type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. + """ + + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.USER_SELECT, init=False, on_setattr=attrs.setters.frozen + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class RoleSelectMenu(BaseSelectMenu): + """ + Represents a role select component. + + Attributes: + custom_id str: A developer-defined identifier for the button, max 100 characters. + placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. + min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) + max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) + disabled bool: Disable the select and make it not intractable, default false. + type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. + """ + + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.ROLE_SELECT, init=False, on_setattr=attrs.setters.frozen + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MentionableSelectMenu(BaseSelectMenu): + """ + Represents a mentionable select component. + + Attributes: + custom_id str: A developer-defined identifier for the button, max 100 characters. + placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. + min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) + max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) + disabled bool: Disable the select and make it not intractable, default false. + type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. + """ + + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.MENTIONABLE_SELECT, init=False, on_setattr=attrs.setters.frozen + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class ChannelSelectMenu(BaseSelectMenu): + """ + Represents a channel select component. + + Attributes: + channel_types List[ChannelTypes]: List of channel types to include in the selection + custom_id str: A developer-defined identifier for the button, max 100 characters. + placeholder str: The custom placeholder text to show if nothing is selected, max 100 characters. + min_values Optional[int]: The minimum number of items that must be chosen. (default 1, min 0, max 25) + max_values Optional[int]: The maximum number of items that can be chosen. (default 1, max 25) + disabled bool: Disable the select and make it not intractable, default false. + type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. + """ + + channel_types: list[ChannelTypes] = attrs.field(factory=list, repr=True, converter=list_converter(ChannelTypes)) + + type: Union[ComponentTypes, int] = attrs.field( + repr=True, default=ComponentTypes.CHANNEL_SELECT, init=False, on_setattr=attrs.setters.frozen + ) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ActionRow(BaseComponent): """ Represents an action row. Attributes: - components List[Union[dict, Select, Button]]: The components within this action row + components List[Union[dict, StringSelectMenu, Button]]: The components within this action row type Union[ComponentTypes, int]: The action role type number defined by discord. This cannot be modified. """ _max_items = ACTION_ROW_MAX_ITEMS - components: Sequence[Union[dict, Select, Button]] = field(repr=True, factory=list) - type: Union[ComponentTypes, int] = field( - default=ComponentTypes.ACTION_ROW, init=False, on_setattr=attrs.setters.frozen + components: Sequence[Union[dict, StringSelectMenu, Button]] = attrs.field(repr=True, factory=list) + type: Union[ComponentTypes, int] = attrs.field( + repr=False, default=ComponentTypes.ACTION_ROW, init=False, on_setattr=attrs.setters.frozen ) - def __init__(self, *components: Union[dict, Select, Button]) -> None: + def __init__(self, *components: Union[dict, StringSelectMenu, Button]) -> None: self.__attrs_init__(components) self.components = [self._component_checks(c) for c in self.components] @@ -256,7 +379,7 @@ def __len__(self) -> int: def from_dict(cls, data) -> "ActionRow": return cls(*data["components"]) - def _component_checks(self, component: Union[dict, Select, Button]) -> Union[Select, Button]: + def _component_checks(self, component: Union[dict, StringSelectMenu, Button]) -> Union[StringSelectMenu, Button]: if isinstance(component, dict): component = BaseComponent.from_dict_factory(component) @@ -270,10 +393,10 @@ def _check_object(self) -> None: if not (0 < len(self.components) <= ActionRow._max_items): raise TypeError(f"Number of components in one row should be between 1 and {ActionRow._max_items}.") - if any(x.type == ComponentTypes.SELECT for x in self.components) and len(self.components) != 1: + if any(x.type == ComponentTypes.STRING_SELECT for x in self.components) and len(self.components) != 1: raise TypeError("Action row must have only one select component and nothing else.") - def add_components(self, *components: Union[dict, Button, Select]) -> None: + def add_components(self, *components: Union[dict, Button, StringSelectMenu]) -> None: """ Add one or more component(s) to this action row. @@ -335,7 +458,7 @@ def process_components( raise ValueError(f"Invalid components: {components}") -def spread_to_rows(*components: Union[ActionRow, Button, Select], max_in_row: int = 5) -> List[ActionRow]: +def spread_to_rows(*components: Union[ActionRow, Button, StringSelectMenu], max_in_row: int = 5) -> List[ActionRow]: """ A helper function that spreads your components into `ActionRow`s of a set size. @@ -375,7 +498,7 @@ def spread_to_rows(*components: Union[ActionRow, Button, Select], max_in_row: in if component is not None: if component.type == ComponentTypes.ACTION_ROW: rows.append(component) - elif component.type == ComponentTypes.SELECT: + elif component.type == ComponentTypes.STRING_SELECT: rows.append(ActionRow(component)) if button_row: rows.append(ActionRow(*button_row)) @@ -418,10 +541,14 @@ def get_components_ids(component: Union[str, dict, list, InteractiveComponent]) raise ValueError(f"Unknown component type of {component} ({type(component)}). " f"Expected str, dict or list") -TYPE_ALL_COMPONENT = Union[ActionRow, Button, Select] +TYPE_ALL_COMPONENT = Union[ActionRow, Button, StringSelectMenu] TYPE_COMPONENT_MAPPING = { ComponentTypes.ACTION_ROW: ActionRow, ComponentTypes.BUTTON: Button, - ComponentTypes.SELECT: Select, + ComponentTypes.STRING_SELECT: StringSelectMenu, + ComponentTypes.USER_SELECT: UserSelectMenu, + ComponentTypes.CHANNEL_SELECT: ChannelSelectMenu, + ComponentTypes.ROLE_SELECT: RoleSelectMenu, + ComponentTypes.MENTIONABLE_SELECT: MentionableSelectMenu, } diff --git a/naff/models/discord/embed.py b/naff/models/discord/embed.py index c75a6ee82..88c899fff 100644 --- a/naff/models/discord/embed.py +++ b/naff/models/discord/embed.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Union +import attrs from attrs.validators import instance_of from attrs.validators import optional as v_optional @@ -12,9 +13,8 @@ EMBED_FIELD_VALUE_LENGTH, ) from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field -from naff.client.utils.attr_converters import timestamp_converter from naff.client.utils.attr_converters import optional as c_optional +from naff.client.utils.attr_converters import timestamp_converter from naff.client.utils.serializer import no_export_meta, export_converter from naff.models.discord.color import Color, process_color from naff.models.discord.enums import EmbedTypes @@ -32,7 +32,7 @@ ) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class EmbedField(DictSerializationMixin): """ Representation of an embed field. @@ -44,9 +44,13 @@ class EmbedField(DictSerializationMixin): """ - name: str = field() - value: str = field() - inline: bool = field(default=False) + name: str = attrs.field( + repr=False, + ) + value: str = attrs.field( + repr=False, + ) + inline: bool = attrs.field(repr=False, default=False) @name.validator def _name_validation(self, attribute: str, value: Any) -> None: @@ -62,7 +66,7 @@ def __len__(self) -> int: return len(self.name) + len(self.value) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class EmbedAuthor(DictSerializationMixin): """ Representation of an embed author. @@ -75,10 +79,10 @@ class EmbedAuthor(DictSerializationMixin): """ - name: Optional[str] = field(default=None) - url: Optional[str] = field(default=None) - icon_url: Optional[str] = field(default=None) - proxy_icon_url: Optional[str] = field(default=None, metadata=no_export_meta) + name: Optional[str] = attrs.field(repr=False, default=None) + url: Optional[str] = attrs.field(repr=False, default=None) + icon_url: Optional[str] = attrs.field(repr=False, default=None) + proxy_icon_url: Optional[str] = attrs.field(repr=False, default=None, metadata=no_export_meta) @name.validator def _name_validation(self, attribute: str, value: Any) -> None: @@ -89,7 +93,7 @@ def __len__(self) -> int: return len(self.name) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class EmbedAttachment(DictSerializationMixin): # thumbnail or image or video """ Representation of an attachment. @@ -102,10 +106,10 @@ class EmbedAttachment(DictSerializationMixin): # thumbnail or image or video """ - url: Optional[str] = field(default=None) - proxy_url: Optional[str] = field(default=None, metadata=no_export_meta) - height: Optional[int] = field(default=None, metadata=no_export_meta) - width: Optional[int] = field(default=None, metadata=no_export_meta) + url: Optional[str] = attrs.field(repr=False, default=None) + proxy_url: Optional[str] = attrs.field(repr=False, default=None, metadata=no_export_meta) + height: Optional[int] = attrs.field(repr=False, default=None, metadata=no_export_meta) + width: Optional[int] = attrs.field(repr=False, default=None, metadata=no_export_meta) @classmethod def _process_dict(cls, data: Dict[str, Any]) -> Dict[str, Any]: @@ -118,7 +122,7 @@ def size(self) -> tuple[Optional[int], Optional[int]]: return self.height, self.width -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class EmbedFooter(DictSerializationMixin): """ Representation of an Embed Footer. @@ -130,9 +134,11 @@ class EmbedFooter(DictSerializationMixin): """ - text: str = field() - icon_url: Optional[str] = field(default=None) - proxy_icon_url: Optional[str] = field(default=None, metadata=no_export_meta) + text: str = attrs.field( + repr=False, + ) + icon_url: Optional[str] = attrs.field(repr=False, default=None) + proxy_icon_url: Optional[str] = attrs.field(repr=False, default=None, metadata=no_export_meta) @classmethod def converter(cls, ingest: Union[dict, str, "EmbedFooter"]) -> "EmbedFooter": @@ -154,7 +160,7 @@ def __len__(self) -> int: return len(self.text) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class EmbedProvider(DictSerializationMixin): """ Represents an embed's provider. @@ -168,50 +174,56 @@ class EmbedProvider(DictSerializationMixin): """ - name: Optional[str] = field(default=None) - url: Optional[str] = field(default=None) + name: Optional[str] = attrs.field(repr=False, default=None) + url: Optional[str] = attrs.field(repr=False, default=None) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Embed(DictSerializationMixin): """Represents a discord embed object.""" - title: Optional[str] = field(default=None, repr=True) + title: Optional[str] = attrs.field(default=None, repr=True) """The title of the embed""" - description: Optional[str] = field(default=None, repr=True) + description: Optional[str] = attrs.field(default=None, repr=True) """The description of the embed""" - color: Optional[Union[Color, dict, tuple, list, str, int]] = field( + color: Optional[Union[Color, dict, tuple, list, str, int]] = attrs.field( default=None, repr=True, metadata=export_converter(process_color) ) """The colour of the embed""" - url: Optional[str] = field(default=None, validator=v_optional(instance_of(str)), repr=True) + url: Optional[str] = attrs.field(default=None, validator=v_optional(instance_of(str)), repr=True) """The url the embed should direct to when clicked""" - timestamp: Optional[Timestamp] = field( + timestamp: Optional[Timestamp] = attrs.field( default=None, converter=c_optional(timestamp_converter), validator=v_optional(instance_of((datetime, float, int))), repr=True, ) """Timestamp of embed content""" - fields: List[EmbedField] = field(factory=list, converter=EmbedField.from_list, repr=True) + fields: List[EmbedField] = attrs.field(factory=list, converter=EmbedField.from_list, repr=True) """A list of [fields][naff.models.discord.embed.EmbedField] to go in the embed""" - author: Optional[EmbedAuthor] = field(default=None, converter=c_optional(EmbedAuthor.from_dict)) + author: Optional[EmbedAuthor] = attrs.field(repr=False, default=None, converter=c_optional(EmbedAuthor.from_dict)) """The author of the embed""" - thumbnail: Optional[EmbedAttachment] = field(default=None, converter=c_optional(EmbedAttachment.from_dict)) + thumbnail: Optional[EmbedAttachment] = attrs.field( + repr=False, default=None, converter=c_optional(EmbedAttachment.from_dict) + ) """The thumbnail of the embed""" - image: Optional[EmbedAttachment] = field(default=None, converter=c_optional(EmbedAttachment.from_dict)) + image: Optional[EmbedAttachment] = attrs.field( + repr=False, default=None, converter=c_optional(EmbedAttachment.from_dict) + ) """The image of the embed""" - video: Optional[EmbedAttachment] = field( - default=None, converter=c_optional(EmbedAttachment.from_dict), metadata=no_export_meta + video: Optional[EmbedAttachment] = attrs.field( + repr=False, default=None, converter=c_optional(EmbedAttachment.from_dict), metadata=no_export_meta ) """The video of the embed, only used by system embeds""" - footer: Optional[EmbedFooter] = field(default=None, converter=c_optional(EmbedFooter.converter)) + footer: Optional[EmbedFooter] = attrs.field(repr=False, default=None, converter=c_optional(EmbedFooter.converter)) """The footer of the embed""" - provider: Optional[EmbedProvider] = field( - default=None, converter=c_optional(EmbedProvider.from_dict), metadata=no_export_meta + provider: Optional[EmbedProvider] = attrs.field( + repr=False, default=None, converter=c_optional(EmbedProvider.from_dict), metadata=no_export_meta ) """The provider of the embed, only used for system embeds""" - type: EmbedTypes = field(default=EmbedTypes.RICH, converter=c_optional(EmbedTypes), metadata=no_export_meta) + type: EmbedTypes = attrs.field( + repr=False, default=EmbedTypes.RICH, converter=c_optional(EmbedTypes), metadata=no_export_meta + ) @title.validator def _name_validation(self, attribute: str, value: Any) -> None: @@ -341,6 +353,25 @@ def add_field(self, name: str, value: Any, inline: bool = False) -> None: self.fields.append(EmbedField(name, str(value), inline)) self._fields_validation("fields", self.fields) + def add_fields(self, *fields: EmbedField | str | dict) -> None: + """ + Add multiple fields to the embed. + + Args: + fields: The fields to add + + """ + for _field in fields: + if isinstance(_field, EmbedField): + self.fields.append(_field) + self._fields_validation("fields", self.fields) + elif isinstance(_field, str): + self.add_field(_field, _field) + elif isinstance(_field, dict): + self.add_field(**_field) + else: + raise TypeError(f"Expected EmbedField, str or dict, got {type(_field).__name__}") + def process_embeds(embeds: Optional[Union[List[Union[Embed, Dict]], Union[Embed, Dict]]]) -> Optional[List[dict]]: """ diff --git a/naff/models/discord/emoji.py b/naff/models/discord/emoji.py index 023ac71d4..bc16bb130 100644 --- a/naff/models/discord/emoji.py +++ b/naff/models/discord/emoji.py @@ -1,8 +1,12 @@ import re +import string +import unicodedata from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import attrs +import emoji + from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import list_converter from naff.client.utils.attr_converters import optional from naff.client.utils.serializer import dict_filter_none, no_export_meta @@ -19,23 +23,24 @@ __all__ = ("PartialEmoji", "CustomEmoji", "process_emoji_req_format", "process_emoji") emoji_regex = re.compile(r"?") +unicode_emoji_reg = re.compile(r"[^\w\s,’‘“”â€Ļ–—â€ĸâ—Ļâ€Ŗ⁃⁎⁏⁒⁓âēâģâŧâŊ⁞âŋ₊₋₌₍₎]") -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class PartialEmoji(SnowflakeObject, DictSerializationMixin): """Represent a basic ("partial") emoji used in discord.""" - id: Optional["Snowflake_Type"] = field( + id: Optional["Snowflake_Type"] = attrs.field( repr=True, default=None, converter=optional(to_snowflake) ) # can be None for Standard Emoji """The custom emoji id. Leave empty if you are using standard unicode emoji.""" - name: Optional[str] = field(repr=True, default=None) + name: Optional[str] = attrs.field(repr=True, default=None) """The custom emoji name, or standard unicode emoji in string""" - animated: bool = field(repr=True, default=False) + animated: bool = attrs.field(repr=True, default=False) """Whether this emoji is animated""" @classmethod - def from_str(cls, emoji_str: str) -> "PartialEmoji": + def from_str(cls, emoji_str: str, *, language: str = "alias") -> Optional["PartialEmoji"]: """ Generate a PartialEmoji from a discord Emoji string representation, or unicode emoji. @@ -45,9 +50,11 @@ def from_str(cls, emoji_str: str) -> "PartialEmoji": a:emoji_name:emoji_id 👋 + :wave: Args: emoji_str: The string representation an emoji + language: The language to use for the unicode emoji parsing Returns: A PartialEmoji object @@ -61,10 +68,29 @@ def from_str(cls, emoji_str: str) -> "PartialEmoji": parsed = tuple(filter(None, parsed[0])) if len(parsed) == 3: return cls(name=parsed[1], id=parsed[2], animated=True) - else: + elif len(parsed) == 2: return cls(name=parsed[0], id=parsed[1]) + else: + _name = emoji.emojize(emoji_str, language=language) + _emoji_list = emoji.distinct_emoji_list(_name) + if _emoji_list: + return cls(name=_emoji_list[0]) else: - return cls(name=emoji_str) + # check if it's a unicode emoji + _emoji_list = emoji.distinct_emoji_list(emoji_str) + if _emoji_list: + return cls(name=_emoji_list[0]) + + # the emoji lib handles *most* emoji, however there are certain ones that it misses + # this acts as a fallback check + if matches := unicode_emoji_reg.search(emoji_str): + match = matches.group() + + # the regex will match certain special characters, so this acts as a final failsafe + if match not in string.printable: + if unicodedata.category(match) == "So": + return cls(name=match) + return None def __str__(self) -> str: s = self.req_format @@ -88,22 +114,24 @@ def req_format(self) -> str: return self.name -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class CustomEmoji(PartialEmoji, ClientObject): """Represent a custom emoji in a guild with all its properties.""" - _client: "Client" = field(metadata=no_export_meta) + _client: "Client" = attrs.field(repr=False, metadata=no_export_meta) - require_colons: bool = field(default=False) + require_colons: bool = attrs.field(repr=False, default=False) """Whether this emoji must be wrapped in colons""" - managed: bool = field(default=False) + managed: bool = attrs.field(repr=False, default=False) """Whether this emoji is managed""" - available: bool = field(default=False) + available: bool = attrs.field(repr=False, default=False) """Whether this emoji can be used, may be false due to loss of Server Boosts.""" - _creator_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) - _role_ids: List["Snowflake_Type"] = field(factory=list, converter=optional(list_converter(to_snowflake))) - _guild_id: "Snowflake_Type" = field(default=None, converter=to_snowflake) + _creator_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) + _role_ids: List["Snowflake_Type"] = attrs.field( + repr=False, factory=list, converter=optional(list_converter(to_snowflake)) + ) + _guild_id: "Snowflake_Type" = attrs.field(repr=False, default=None, converter=to_snowflake) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -148,6 +176,7 @@ def is_usable(self) -> bool: async def edit( self, + *, name: Optional[str] = None, roles: Optional[List[Union["Snowflake_Type", "Role"]]] = None, reason: Optional[str] = None, diff --git a/naff/models/discord/enums.py b/naff/models/discord/enums.py index 56e377a76..e53556fe1 100644 --- a/naff/models/discord/enums.py +++ b/naff/models/discord/enums.py @@ -1,10 +1,9 @@ -from enum import Enum, EnumMeta, IntEnum, IntFlag, _decompose +from enum import Enum, EnumMeta, IntEnum, IntFlag from functools import reduce from operator import or_ from typing import Iterator, Tuple, TypeVar, Type -from naff.client.const import logger - +from naff.client.const import get_logger __all__ = ( "WebSocketOPCodes", @@ -44,6 +43,7 @@ "ScheduledEventType", "ScheduledEventStatus", "AuditLogEventType", + "InteractionPermissionTypes", ) @@ -61,6 +61,38 @@ def _distinct(source) -> Tuple: return (x for x in source if (x.value & (x.value - 1)) == 0 and x.value != 0) +def _decompose(flag, value): # noqa + """ + Extract all members from the value. + + Source: Python 3.10.8 source + """ + # _decompose is only called if the value is not named + not_covered = value + negative = value < 0 + members = [] + for member in flag: + member_value = member.value + if member_value and member_value & value == member_value: + members.append(member) + not_covered &= ~member_value + if not negative: + tmp = not_covered + while tmp: + flag_value = 2 ** (tmp.bit_length() - 1) + if flag_value in flag._value2member_map_: + members.append(flag._value2member_map_[flag_value]) + not_covered &= ~flag_value + tmp &= ~flag_value + if not members and value in flag._value2member_map_: + members.append(flag._value2member_map_[value]) + members.sort(key=lambda m: m._value_, reverse=True) + if len(members) > 1 and members[0].value == value: + # we have the breakdown, don't need the value member itself + members.pop(0) + return members, not_covered + + class DistinctFlag(EnumMeta): def __iter__(cls) -> Iterator: yield from _distinct(super().__iter__()) @@ -83,7 +115,7 @@ def __iter__(self) -> Iterator: def _log_type_mismatch(cls, value) -> None: - logger.error( + get_logger().error( f"Class `{cls.__name__}` received an invalid and unexpected value `{value}`. Please update NAFF or report this issue on GitHub - https://github.com/NAFTeam/NAFF/issues" ) @@ -203,7 +235,7 @@ def new( typing=False, privileged=False, non_privileged=False, - default=True, + default=False, all=False, ) -> "Intents": """Set your desired intents.""" @@ -258,6 +290,8 @@ class UserFlags(DiscordIntFlag): # type: ignore """A user who is suspected of spamming""" DISABLE_PREMIUM = 1 << 21 """Nitro features disabled for this user. Only used by Discord Staff for testing""" + ACTIVE_DEVELOPER = 1 << 22 + """This user is an active developer""" # Shortcuts/grouping/aliases HYPESQUAD = HOUSE_BRAVERY | HOUSE_BRILLIANCE | HOUSE_BALANCE @@ -542,10 +576,18 @@ class ComponentTypes(CursedIntEnum): """Container for other components""" BUTTON = 2 """Button object""" - SELECT = 3 - """Select menu for picking from choices""" + STRING_SELECT = 3 + """Select menu for picking from text choices""" INPUT_TEXT = 4 """Text input object""" + USER_SELECT = 5 + """Select menu for picking from users""" + ROLE_SELECT = 6 + """Select menu for picking from roles""" + MENTIONABLE_SELECT = 7 + """Select menu for picking from mentionable objects""" + CHANNEL_SELECT = 8 + """Select menu for picking from channels""" class CommandTypes(CursedIntEnum): @@ -569,6 +611,14 @@ class InteractionTypes(CursedIntEnum): MODAL_RESPONSE = 5 +class InteractionPermissionTypes(CursedIntEnum): + """The type of interaction permission received by discord.""" + + ROLE = 1 + USER = 2 + CHANNEL = 3 + + class ButtonStyles(CursedIntEnum): """The styles of buttons supported.""" diff --git a/naff/models/discord/file.py b/naff/models/discord/file.py index 55fa9671c..7f1465361 100644 --- a/naff/models/discord/file.py +++ b/naff/models/discord/file.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import BinaryIO, Optional, Union -from naff.client.utils.attr_utils import define, field +import attrs __all__ = ("File", "open_file", "UPLOADABLE_TYPE") -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class File: """ Representation of a file. @@ -16,9 +16,9 @@ class File: """ - file: Union["IOBase", BinaryIO, "Path", str] = field(repr=True) + file: Union["IOBase", BinaryIO, "Path", str] = attrs.field(repr=True) """Location of file to send or the bytes.""" - file_name: Optional[str] = field(repr=True, default=None) + file_name: Optional[str] = attrs.field(repr=True, default=None) """Set a filename that will be displayed when uploaded to discord. If you leave this empty, the file will be called `file` by default""" def open_file(self) -> BinaryIO: diff --git a/naff/models/discord/guild.py b/naff/models/discord/guild.py index 9e6f2bbda..8b62734c8 100644 --- a/naff/models/discord/guild.py +++ b/naff/models/discord/guild.py @@ -6,15 +6,19 @@ from typing import List, Optional, Union, Set, Dict, Any, TYPE_CHECKING from warnings import warn +import attrs + import naff.models as models -from naff.client.const import MISSING, PREMIUM_GUILD_LIMITS, logger, Absent +from naff.client.const import Absent, MISSING, PREMIUM_GUILD_LIMITS from naff.client.errors import EventLocationNotProvided, NotFound from naff.client.mixins.serialization import DictSerializationMixin from naff.client.utils.attr_converters import optional from naff.client.utils.attr_converters import timestamp_converter -from naff.client.utils.attr_utils import define, field, docs +from naff.client.utils.attr_utils import docs from naff.client.utils.deserialise_app_cmds import deserialize_app_cmds from naff.client.utils.serializer import to_image_data, no_export_meta +from naff.models.discord.app_perms import CommandPermissions, ApplicationCommandPermission +from naff.models.discord.auto_mod import AutoModRule, BaseAction, BaseTrigger from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.misc.iterator import AsyncIterator from .base import DiscordObject, ClientObject @@ -34,7 +38,6 @@ AutoModEvent, AutoModTriggerType, ) -from naff.models.discord.auto_mod import AutoModRule, BaseAction, BaseTrigger from .snowflake import to_snowflake, Snowflake_Type, to_optional_snowflake, to_snowflake_list if TYPE_CHECKING: @@ -59,7 +62,7 @@ ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildBan: reason: Optional[str] """The reason for the ban""" @@ -67,20 +70,20 @@ class GuildBan: """The banned user""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseGuild(DiscordObject): - name: str = field(repr=True) + name: str = attrs.field(repr=True) """Name of guild. (2-100 characters, excluding trailing and leading whitespace)""" - description: Optional[str] = field(repr=True, default=None) + description: Optional[str] = attrs.field(repr=True, default=None) """The description for the guild, if the guild is discoverable""" - icon: Optional["models.Asset"] = field(default=None) + icon: Optional["models.Asset"] = attrs.field(repr=False, default=None) """Icon image asset""" - splash: Optional["models.Asset"] = field(default=None) + splash: Optional["models.Asset"] = attrs.field(repr=False, default=None) """Splash image asset""" - discovery_splash: Optional["models.Asset"] = field(default=None) + discovery_splash: Optional["models.Asset"] = attrs.field(repr=False, default=None) """Discovery splash image. Only present for guilds with the "DISCOVERABLE" feature.""" - features: List[str] = field(factory=list) + features: List[str] = attrs.field(repr=False, factory=list) """The features of this guild""" @classmethod @@ -96,23 +99,25 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildWelcome(ClientObject): - description: Optional[str] = field(default=None, metadata=docs("Welcome Screen server description")) - welcome_channels: List["models.GuildWelcomeChannel"] = field( - metadata=docs("List of Welcome Channel objects, up to 5") + description: Optional[str] = attrs.field( + repr=False, default=None, metadata=docs("Welcome Screen server description") + ) + welcome_channels: List["models.GuildWelcomeChannel"] = attrs.field( + repr=False, metadata=docs("List of Welcome Channel objects, up to 5") ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildPreview(BaseGuild): """A partial guild object.""" - emoji: list["models.PartialEmoji"] = field(factory=list) + emoji: list["models.PartialEmoji"] = attrs.field(repr=False, factory=list) """A list of custom emoji from this guild""" - approximate_member_count: int = field(default=0) + approximate_member_count: int = attrs.field(repr=False, default=0) """Approximate number of members in this guild""" - approximate_presence_count: int = field(default=0) + approximate_presence_count: int = attrs.field(repr=False, default=0) """Approximate number of online members in this guild""" @classmethod @@ -140,85 +145,92 @@ async def fetch(self) -> list: raise QueueEmpty -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Guild(BaseGuild): """Guilds in Discord represent an isolated collection of users and channels, and are often referred to as "servers" in the UI.""" - unavailable: bool = field(default=False) + unavailable: bool = attrs.field(repr=False, default=False) """True if this guild is unavailable due to an outage.""" - # owner: bool = field(default=False) # we get this from api but it's kinda useless to store - afk_channel_id: Optional[Snowflake_Type] = field(default=None) + # owner: bool = attrs.field(repr=False, default=False) # we get this from api but it's kinda useless to store + afk_channel_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """The channel id for afk.""" - afk_timeout: Optional[int] = field(default=None) + afk_timeout: Optional[int] = attrs.field(repr=False, default=None) """afk timeout in seconds.""" - widget_enabled: bool = field(default=False) + widget_enabled: bool = attrs.field(repr=False, default=False) """True if the server widget is enabled.""" - widget_channel_id: Optional[Snowflake_Type] = field(default=None) + widget_channel_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """The channel id that the widget will generate an invite to, or None if set to no invite.""" - verification_level: Union[VerificationLevels, int] = field(default=VerificationLevels.NONE) + verification_level: Union[VerificationLevels, int] = attrs.field(repr=False, default=VerificationLevels.NONE) """The verification level required for the guild.""" - default_message_notifications: Union[DefaultNotificationLevels, int] = field( - default=DefaultNotificationLevels.ALL_MESSAGES + default_message_notifications: Union[DefaultNotificationLevels, int] = attrs.field( + repr=False, default=DefaultNotificationLevels.ALL_MESSAGES ) """The default message notifications level.""" - explicit_content_filter: Union[ExplicitContentFilterLevels, int] = field( - default=ExplicitContentFilterLevels.DISABLED + explicit_content_filter: Union[ExplicitContentFilterLevels, int] = attrs.field( + repr=False, default=ExplicitContentFilterLevels.DISABLED ) """The explicit content filter level.""" - mfa_level: Union[MFALevels, int] = field(default=MFALevels.NONE) + mfa_level: Union[MFALevels, int] = attrs.field(repr=False, default=MFALevels.NONE) """The required MFA (Multi Factor Authentication) level for the guild.""" - system_channel_id: Optional[Snowflake_Type] = field(default=None) + system_channel_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """The id of the channel where guild notices such as welcome messages and boost events are posted.""" - system_channel_flags: SystemChannelFlags = field(default=SystemChannelFlags.NONE, converter=SystemChannelFlags) + system_channel_flags: SystemChannelFlags = attrs.field( + repr=False, default=SystemChannelFlags.NONE, converter=SystemChannelFlags + ) """The system channel flags.""" - rules_channel_id: Optional[Snowflake_Type] = field(default=None) + rules_channel_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """The id of the channel where Community guilds can display rules and/or guidelines.""" - joined_at: str = field(default=None, converter=optional(timestamp_converter)) + joined_at: str = attrs.field(repr=False, default=None, converter=optional(timestamp_converter)) """When this guild was joined at.""" - large: bool = field(default=False) + large: bool = attrs.field(repr=False, default=False) """True if this is considered a large guild.""" - member_count: int = field(default=0) + member_count: int = attrs.field(repr=False, default=0) """The total number of members in this guild.""" - presences: List[dict] = field(factory=list) + presences: List[dict] = attrs.field(repr=False, factory=list) """The presences of the members in the guild, will only include non-offline members if the size is greater than large threshold.""" - max_presences: Optional[int] = field(default=None) + max_presences: Optional[int] = attrs.field(repr=False, default=None) """The maximum number of presences for the guild. (None is always returned, apart from the largest of guilds)""" - max_members: Optional[int] = field(default=None) + max_members: Optional[int] = attrs.field(repr=False, default=None) """The maximum number of members for the guild.""" - vanity_url_code: Optional[str] = field(default=None) + vanity_url_code: Optional[str] = attrs.field(repr=False, default=None) """The vanity url code for the guild.""" - banner: Optional[str] = field(default=None) + banner: Optional[str] = attrs.field(repr=False, default=None) """Hash for banner image.""" - premium_tier: Optional[str] = field(default=None) + premium_tier: Optional[str] = attrs.field(repr=False, default=None) """The premium tier level. (Server Boost level)""" - premium_subscription_count: int = field(default=0) + premium_subscription_count: int = attrs.field(repr=False, default=0) """The number of boosts this guild currently has.""" - preferred_locale: str = field() + preferred_locale: str = attrs.field( + repr=False, + ) """The preferred locale of a Community guild. Used in server discovery and notices from Discord. Defaults to \"en-US\"""" - public_updates_channel_id: Optional[Snowflake_Type] = field(default=None) + public_updates_channel_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=None) """The id of the channel where admins and moderators of Community guilds receive notices from Discord.""" - max_video_channel_users: int = field(default=0) + max_video_channel_users: int = attrs.field(repr=False, default=0) """The maximum amount of users in a video channel.""" - welcome_screen: Optional["GuildWelcome"] = field(default=None) + welcome_screen: Optional["GuildWelcome"] = attrs.field(repr=False, default=None) """The welcome screen of a Community guild, shown to new members, returned in an Invite's guild object.""" - nsfw_level: Union[NSFWLevels, int] = field(default=NSFWLevels.DEFAULT) + nsfw_level: Union[NSFWLevels, int] = attrs.field(repr=False, default=NSFWLevels.DEFAULT) """The guild NSFW level.""" - stage_instances: List[dict] = field(factory=list) # TODO stage instance objects + stage_instances: List[dict] = attrs.field(repr=False, factory=list) # TODO stage instance objects """Stage instances in the guild.""" - chunked = field(factory=asyncio.Event, metadata=no_export_meta) + chunked = attrs.field(repr=False, factory=asyncio.Event, metadata=no_export_meta) """An event that is fired when this guild has been chunked""" + command_permissions: dict[Snowflake_Type, CommandPermissions] = attrs.field( + repr=False, factory=dict, metadata=no_export_meta + ) + """A cache of all command permissions for this guild""" - _owner_id: Snowflake_Type = field(converter=to_snowflake) - _channel_ids: Set[Snowflake_Type] = field(factory=set) - _thread_ids: Set[Snowflake_Type] = field(factory=set) - _member_ids: Set[Snowflake_Type] = field(factory=set) - _role_ids: Set[Snowflake_Type] = field(factory=set) - _chunk_cache: list = field(factory=list) - _channel_gui_positions: Dict[Snowflake_Type, int] = field(factory=dict) + _owner_id: Snowflake_Type = attrs.field(repr=False, converter=to_snowflake) + _channel_ids: Set[Snowflake_Type] = attrs.field(repr=False, factory=set) + _thread_ids: Set[Snowflake_Type] = attrs.field(repr=False, factory=set) + _member_ids: Set[Snowflake_Type] = attrs.field(repr=False, factory=set) + _role_ids: Set[Snowflake_Type] = attrs.field(repr=False, factory=set) + _chunk_cache: list = attrs.field(repr=False, factory=list) + _channel_gui_positions: Dict[Snowflake_Type, int] = attrs.field(repr=False, factory=dict) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: - # todo: find a away to prevent this loop from blocking the event loop data = super()._process_dict(data, client) guild_id = data["id"] @@ -314,7 +326,8 @@ async def create( @property def channels(self) -> List["models.TYPE_GUILD_CHANNEL"]: """Returns a list of channels associated with this guild.""" - return [self._client.cache.get_channel(c_id) for c_id in self._channel_ids] + channels = [self._client.cache.get_channel(c_id) for c_id in self._channel_ids] + return [c for c in channels if c] @property def threads(self) -> List["models.TYPE_THREAD_CHANNEL"]: @@ -489,6 +502,26 @@ async def fetch_channels(self) -> List["models.TYPE_VOICE_CHANNEL"]: data = await self._client.http.get_guild_channels(self.id) return [self._client.cache.place_channel_data(channel_data) for channel_data in data] + async def fetch_app_cmd_perms(self) -> dict[Snowflake_Type, "CommandPermissions"]: + """ + Fetch the application command permissions for this guild. + + Returns: + The application command permissions for this guild. + + """ + data = await self._client.http.batch_get_application_command_permissions(self._client.app.id, self.id) + + for command in data: + command_permissions = CommandPermissions(client=self._client, command_id=command["id"], guild=self) + perms = [ApplicationCommandPermission.from_dict(perm, self) for perm in command["permissions"]] + + command_permissions.update_permissions(*perms) + + self.command_permissions[int(command["id"])] = command_permissions + + return self.command_permissions + def is_owner(self, user: Snowflake_Type) -> bool: """ Whether the user is owner of the guild. @@ -528,7 +561,7 @@ async def http_chunk(self) -> None: self._client.cache.place_member_data(self.id, member) self.chunked.set() - logger.info( + self.logger.info( f"Cached {iterator.total_retrieved} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds" ) @@ -596,10 +629,10 @@ async def process_member_chunk(self, chunk: dict) -> None: self._chunk_cache = self._chunk_cache + chunk.get("members") if chunk.get("chunk_index") != chunk.get("chunk_count") - 1: - return logger.debug(f"Cached chunk of {len(chunk.get('members'))} members for {self.id}") + return self.logger.debug(f"Cached chunk of {len(chunk.get('members'))} members for {self.id}") else: members = self._chunk_cache - logger.info(f"Processing {len(members)} members for {self.id}") + self.logger.info(f"Processing {len(members)} members for {self.id}") s = time.monotonic() start_time = time.perf_counter() @@ -615,7 +648,7 @@ async def process_member_chunk(self, chunk: dict) -> None: total_time = time.perf_counter() - start_time self.chunk_cache = [] - logger.info(f"Cached members for {self.id} in {total_time:.2f} seconds") + self.logger.info(f"Cached members for {self.id} in {total_time:.2f} seconds") self.chunked.set() async def fetch_audit_log( @@ -683,6 +716,7 @@ def audit_log_history( async def edit( self, + *, name: Absent[Optional[str]] = MISSING, description: Absent[Optional[str]] = MISSING, verification_level: Absent[Optional["VerificationLevels"]] = MISSING, @@ -1972,24 +2006,34 @@ def channel_sort_func(a, b) -> int: return sorted_channels -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildTemplate(ClientObject): - code: str = field(repr=True, metadata=docs("the template code (unique ID)")) - name: str = field(repr=True, metadata=docs("the name")) - description: Optional[str] = field(default=None, metadata=docs("the description")) + code: str = attrs.field(repr=True, metadata=docs("the template code (unique ID)")) + name: str = attrs.field(repr=True, metadata=docs("the name")) + description: Optional[str] = attrs.field(repr=False, default=None, metadata=docs("the description")) - usage_count: int = field(default=0, metadata=docs("number of times this template has been used")) + usage_count: int = attrs.field(repr=False, default=0, metadata=docs("number of times this template has been used")) - creator_id: Snowflake_Type = field(metadata=docs("The ID of the user who created this template")) - creator: Optional["models.User"] = field(default=None, metadata=docs("the user who created this template")) + creator_id: Snowflake_Type = attrs.field(repr=False, metadata=docs("The ID of the user who created this template")) + creator: Optional["models.User"] = attrs.field( + repr=False, default=None, metadata=docs("the user who created this template") + ) - created_at: "models.Timestamp" = field(metadata=docs("When this template was created")) - updated_at: "models.Timestamp" = field(metadata=docs("When this template was last synced to the source guild")) + created_at: "models.Timestamp" = attrs.field(repr=False, metadata=docs("When this template was created")) + updated_at: "models.Timestamp" = attrs.field( + repr=False, metadata=docs("When this template was last synced to the source guild") + ) - source_guild_id: Snowflake_Type = field(metadata=docs("The ID of the guild this template is based on")) - guild_snapshot: "models.Guild" = field(metadata=docs("A snapshot of the guild this template contains")) + source_guild_id: Snowflake_Type = attrs.field( + repr=False, metadata=docs("The ID of the guild this template is based on") + ) + guild_snapshot: "models.Guild" = attrs.field( + repr=False, metadata=docs("A snapshot of the guild this template contains") + ) - is_dirty: bool = field(default=False, metadata=docs("Whether this template has un-synced changes")) + is_dirty: bool = attrs.field( + repr=False, default=False, metadata=docs("Whether this template has un-synced changes") + ) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -2028,48 +2072,54 @@ async def delete(self) -> None: await self._client.http.delete_guild_template(self.source_guild_id, self.code) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class GuildWelcomeChannel(ClientObject): - channel_id: Snowflake_Type = field(repr=True, metadata=docs("Welcome Channel ID")) - description: str = field(metadata=docs("Welcome Channel description")) - emoji_id: Optional[Snowflake_Type] = field( - default=None, metadata=docs("Welcome Channel emoji ID if the emoji is custom") + channel_id: Snowflake_Type = attrs.field(repr=True, metadata=docs("Welcome Channel ID")) + description: str = attrs.field(repr=False, metadata=docs("Welcome Channel description")) + emoji_id: Optional[Snowflake_Type] = attrs.field( + repr=False, default=None, metadata=docs("Welcome Channel emoji ID if the emoji is custom") ) - emoji_name: Optional[str] = field( - default=None, metadata=docs("Emoji name if custom, unicode character if standard") + emoji_name: Optional[str] = attrs.field( + repr=False, default=None, metadata=docs("Emoji name if custom, unicode character if standard") ) class GuildIntegration(DiscordObject): - name: str = field(repr=True) + name: str = attrs.field(repr=True) """The name of the integration""" - type: str = field(repr=True) + type: str = attrs.field(repr=True) """integration type (twitch, youtube, or discord)""" - enabled: bool = field(repr=True) + enabled: bool = attrs.field(repr=True) """is this integration enabled""" - account: dict = field() + account: dict = attrs.field( + repr=False, + ) """integration account information""" - application: Optional["models.Application"] = field(default=None) + application: Optional["models.Application"] = attrs.field(repr=False, default=None) """The bot/OAuth2 application for discord integrations""" - _guild_id: Snowflake_Type = field() + _guild_id: Snowflake_Type = attrs.field( + repr=False, + ) - syncing: Optional[bool] = field(default=MISSING) + syncing: Optional[bool] = attrs.field(repr=False, default=MISSING) """is this integration syncing""" - role_id: Optional[Snowflake_Type] = field(default=MISSING) + role_id: Optional[Snowflake_Type] = attrs.field(repr=False, default=MISSING) """id that this integration uses for "subscribers\"""" - enable_emoticons: bool = field(default=MISSING) + enable_emoticons: bool = attrs.field(repr=False, default=MISSING) """whether emoticons should be synced for this integration (twitch only currently)""" - expire_behavior: IntegrationExpireBehaviour = field(default=MISSING, converter=optional(IntegrationExpireBehaviour)) + expire_behavior: IntegrationExpireBehaviour = attrs.field( + repr=False, default=MISSING, converter=optional(IntegrationExpireBehaviour) + ) """the behavior of expiring subscribers""" - expire_grace_period: int = field(default=MISSING) + expire_grace_period: int = attrs.field(repr=False, default=MISSING) """the grace period (in days) before expiring subscribers""" - user: "models.BaseUser" = field(default=MISSING) + user: "models.BaseUser" = attrs.field(repr=False, default=MISSING) """user for this integration""" - synced_at: "models.Timestamp" = field(default=MISSING, converter=optional(timestamp_converter)) + synced_at: "models.Timestamp" = attrs.field(repr=False, default=MISSING, converter=optional(timestamp_converter)) """when this integration was last synced""" - subscriber_count: int = field(default=MISSING) + subscriber_count: int = attrs.field(repr=False, default=MISSING) """how many subscribers this integration has""" - revoked: bool = field(default=MISSING) + revoked: bool = attrs.field(repr=False, default=MISSING) """has this integration been revoked""" @classmethod @@ -2087,23 +2137,23 @@ async def delete(self, reason: Absent[str] = MISSING) -> None: class GuildWidgetSettings(DictSerializationMixin): - enabled: bool = field(repr=True, default=False) + enabled: bool = attrs.field(repr=True, default=False) """Whether the widget is enabled.""" - channel_id: Optional["Snowflake_Type"] = field(repr=True, default=None, converter=to_optional_snowflake) + channel_id: Optional["Snowflake_Type"] = attrs.field(repr=True, default=None, converter=to_optional_snowflake) """The widget channel id. None if widget is not enabled.""" class GuildWidget(DiscordObject): - name: str = field(repr=True) + name: str = attrs.field(repr=True) """Guild name (2-100 characters)""" - instant_invite: str = field(repr=True, default=None) + instant_invite: str = attrs.field(repr=True, default=None) """Instant invite for the guilds specified widget invite channel""" - presence_count: int = field(repr=True, default=0) + presence_count: int = attrs.field(repr=True, default=0) """Number of online members in this guild""" - _channel_ids: List["Snowflake_Type"] = field(default=[]) + _channel_ids: List["Snowflake_Type"] = attrs.field(repr=False, default=[]) """Voice and stage channels which are accessible by @everyone""" - _member_ids: List["Snowflake_Type"] = field(default=[]) + _member_ids: List["Snowflake_Type"] = attrs.field(repr=False, default=[]) """Special widget user objects that includes users presence (Limit 100)""" @classmethod @@ -2155,29 +2205,29 @@ async def fetch_members(self) -> List["models.User"]: return [await self._client.fetch_user(member_id) for member_id in self._member_ids] -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AuditLogChange(ClientObject): - key: str = field(repr=True) + key: str = attrs.field(repr=True) """name of audit log change key""" - new_value: Optional[Union[list, str, int, bool, "Snowflake_Type"]] = field(default=MISSING) + new_value: Optional[Union[list, str, int, bool, "Snowflake_Type"]] = attrs.field(repr=False, default=MISSING) """new value of the key""" - old_value: Optional[Union[list, str, int, bool, "Snowflake_Type"]] = field(default=MISSING) + old_value: Optional[Union[list, str, int, bool, "Snowflake_Type"]] = attrs.field(repr=False, default=MISSING) """old value of the key""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AuditLogEntry(DiscordObject): - target_id: Optional["Snowflake_Type"] = field(converter=optional(to_snowflake)) + target_id: Optional["Snowflake_Type"] = attrs.field(repr=False, converter=optional(to_snowflake)) """id of the affected entity (webhook, user, role, etc.)""" - user_id: "Snowflake_Type" = field(converter=optional(to_snowflake)) + user_id: "Snowflake_Type" = attrs.field(repr=False, converter=optional(to_snowflake)) """the user who made the changes""" - action_type: "AuditLogEventType" = field(converter=AuditLogEventType) + action_type: "AuditLogEventType" = attrs.field(repr=False, converter=AuditLogEventType) """type of action that occurred""" - changes: Optional[List[AuditLogChange]] = field(default=MISSING) + changes: Optional[List[AuditLogChange]] = attrs.field(repr=False, default=MISSING) """changes made to the target_id""" - options: Optional[Union["Snowflake_Type", str]] = field(default=MISSING) + options: Optional[Union["Snowflake_Type", str]] = attrs.field(repr=False, default=MISSING) """additional info for certain action types""" - reason: Optional[str] = field(default=MISSING) + reason: Optional[str] = attrs.field(repr=False, default=MISSING) """the reason for the change (0-512 characters)""" @classmethod @@ -2188,23 +2238,25 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AuditLog(ClientObject): """Contains entries and other data given from selected""" - application_commands: list["InteractionCommand"] = field(factory=list, converter=optional(deserialize_app_cmds)) + application_commands: list["InteractionCommand"] = attrs.field( + repr=False, factory=list, converter=optional(deserialize_app_cmds) + ) """list of application commands that have had their permissions updated""" - entries: Optional[List["AuditLogEntry"]] = field(default=MISSING) + entries: Optional[List["AuditLogEntry"]] = attrs.field(repr=False, default=MISSING) """list of audit log entries""" - scheduled_events: Optional[List["models.ScheduledEvent"]] = field(default=MISSING) + scheduled_events: Optional[List["models.ScheduledEvent"]] = attrs.field(repr=False, default=MISSING) """list of guild scheduled events found in the audit log""" - integrations: Optional[List["GuildIntegration"]] = field(default=MISSING) + integrations: Optional[List["GuildIntegration"]] = attrs.field(repr=False, default=MISSING) """list of partial integration objects""" - threads: Optional[List["models.ThreadChannel"]] = field(default=MISSING) + threads: Optional[List["models.ThreadChannel"]] = attrs.field(repr=False, default=MISSING) """list of threads found in the audit log""" - users: Optional[List["models.User"]] = field(default=MISSING) + users: Optional[List["models.User"]] = attrs.field(repr=False, default=MISSING) """list of users found in the audit log""" - webhooks: Optional[List["models.Webhook"]] = field(default=MISSING) + webhooks: Optional[List["models.Webhook"]] = attrs.field(repr=False, default=MISSING) """list of webhooks found in the audit log""" @classmethod diff --git a/naff/models/discord/invite.py b/naff/models/discord/invite.py index a9d8f5810..33758217b 100644 --- a/naff/models/discord/invite.py +++ b/naff/models/discord/invite.py @@ -1,7 +1,8 @@ from typing import TYPE_CHECKING, Optional, Union, Dict, Any +import attrs + from naff.client.const import MISSING, Absent -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter from naff.models.discord.application import Application @@ -21,47 +22,51 @@ __all__ = ("Invite",) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Invite(ClientObject): - code: str = field(repr=True) + code: str = attrs.field(repr=True) """the invite code (unique ID)""" # metadata - uses: int = field(default=0, repr=True) + uses: int = attrs.field(default=0, repr=True) """the guild this invite is for""" - max_uses: int = field(default=0) + max_uses: int = attrs.field(repr=False, default=0) """max number of times this invite can be used""" - max_age: int = field(default=0) + max_age: int = attrs.field(repr=False, default=0) """duration (in seconds) after which the invite expires""" - created_at: Timestamp = field(default=MISSING, converter=optional_c(timestamp_converter), repr=True) + created_at: Timestamp = attrs.field(default=MISSING, converter=optional_c(timestamp_converter), repr=True) """when this invite was created""" - temporary: bool = field(default=False, repr=True) + temporary: bool = attrs.field(default=False, repr=True) """whether this invite only grants temporary membership""" # target data - target_type: Optional[Union[InviteTargetTypes, int]] = field( + target_type: Optional[Union[InviteTargetTypes, int]] = attrs.field( default=None, converter=optional_c(InviteTargetTypes), repr=True ) """the type of target for this voice channel invite""" - approximate_presence_count: Optional[int] = field(default=MISSING) + approximate_presence_count: Optional[int] = attrs.field(repr=False, default=MISSING) """approximate count of online members, returned from the `GET /invites/` endpoint when `with_counts` is `True`""" - approximate_member_count: Optional[int] = field(default=MISSING) + approximate_member_count: Optional[int] = attrs.field(repr=False, default=MISSING) """approximate count of total members, returned from the `GET /invites/` endpoint when `with_counts` is `True`""" - scheduled_event: Optional["Snowflake_Type"] = field(default=None, converter=optional_c(to_snowflake), repr=True) + scheduled_event: Optional["Snowflake_Type"] = attrs.field( + default=None, converter=optional_c(to_snowflake), repr=True + ) """guild scheduled event data, only included if `guild_scheduled_event_id` contains a valid guild scheduled event id""" - expires_at: Optional[Timestamp] = field(default=None, converter=optional_c(timestamp_converter), repr=True) + expires_at: Optional[Timestamp] = attrs.field(default=None, converter=optional_c(timestamp_converter), repr=True) """the expiration date of this invite, returned from the `GET /invites/` endpoint when `with_expiration` is `True`""" - stage_instance: Optional[StageInstance] = field(default=None) + stage_instance: Optional[StageInstance] = attrs.field(repr=False, default=None) """stage instance data if there is a public Stage instance in the Stage channel this invite is for (deprecated)""" - target_application: Optional[dict] = field(default=None) + target_application: Optional[dict] = attrs.field(repr=False, default=None) """the embedded application to open for this voice channel embedded application invite""" - guild_preview: Optional[GuildPreview] = field(default=MISSING) + guild_preview: Optional[GuildPreview] = attrs.field(repr=False, default=MISSING) """the guild this invite is for""" # internal for props - _channel_id: "Snowflake_Type" = field(converter=to_snowflake, repr=True) - _inviter_id: Optional["Snowflake_Type"] = field(default=None, converter=optional_c(to_snowflake), repr=True) - _target_user_id: Optional["Snowflake_Type"] = field(default=None, converter=optional_c(to_snowflake)) + _channel_id: "Snowflake_Type" = attrs.field(converter=to_snowflake, repr=True) + _inviter_id: Optional["Snowflake_Type"] = attrs.field(default=None, converter=optional_c(to_snowflake), repr=True) + _target_user_id: Optional["Snowflake_Type"] = attrs.field( + repr=False, default=None, converter=optional_c(to_snowflake) + ) @property def channel(self) -> "TYPE_GUILD_CHANNEL": diff --git a/naff/models/discord/message.py b/naff/models/discord/message.py index 52141e91f..5c1075d3c 100644 --- a/naff/models/discord/message.py +++ b/naff/models/discord/message.py @@ -3,14 +3,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Sequence, Union, Mapping +import attrs + import naff.models as models from naff.client.const import GUILD_WELCOME_MESSAGES, MISSING, Absent from naff.client.errors import EphemeralEditException, ThreadOutsideOfGuild from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter from naff.client.utils.serializer import dict_filter_none +from naff.client.utils.text_utils import mentions +from naff.models.discord.channel import BaseChannel from naff.models.discord.file import UPLOADABLE_TYPE from .base import DiscordObject from .enums import ( @@ -23,10 +26,10 @@ AutoArchiveDuration, ) from .snowflake import to_snowflake, Snowflake_Type, to_snowflake_list, to_optional_snowflake -from naff.models.discord.channel import BaseChannel if TYPE_CHECKING: from naff.client import Client + from naff import InteractionContext __all__ = ( "Attachment", @@ -46,25 +49,33 @@ channel_mention = re.compile(r"<#(?P[0-9]{17,})>") -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Attachment(DiscordObject): - filename: str = field() + filename: str = attrs.field( + repr=False, + ) """name of file attached""" - description: Optional[str] = field(default=None) + description: Optional[str] = attrs.field(repr=False, default=None) """description for the file""" - content_type: Optional[str] = field(default=None) + content_type: Optional[str] = attrs.field(repr=False, default=None) """the attachment's media type""" - size: int = field() + size: int = attrs.field( + repr=False, + ) """size of file in bytes""" - url: str = field() + url: str = attrs.field( + repr=False, + ) """source url of file""" - proxy_url: str = field() + proxy_url: str = attrs.field( + repr=False, + ) """a proxied url of file""" - height: Optional[int] = field(default=None) + height: Optional[int] = attrs.field(repr=False, default=None) """height of file (if image)""" - width: Optional[int] = field(default=None) + width: Optional[int] = attrs.field(repr=False, default=None) """width of file (if image)""" - ephemeral: bool = field(default=False) + ephemeral: bool = attrs.field(repr=False, default=False) """whether this attachment is ephemeral""" @property @@ -73,13 +84,17 @@ def resolution(self) -> tuple[Optional[int], Optional[int]]: return self.height, self.width -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ChannelMention(DiscordObject): - guild_id: "Snowflake_Type" = field() + guild_id: "Snowflake_Type" = attrs.field( + repr=False, + ) """id of the guild containing the channel""" - type: ChannelTypes = field(converter=ChannelTypes) + type: ChannelTypes = attrs.field(repr=False, converter=ChannelTypes) """the type of channel""" - name: str = field() + name: str = attrs.field( + repr=False, + ) """the name of the channel""" @@ -91,7 +106,7 @@ class MessageActivity: """party_id from a Rich Presence event""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class MessageReference(DictSerializationMixin): """ Reference to an originating message. @@ -100,13 +115,13 @@ class MessageReference(DictSerializationMixin): """ - message_id: int = field(default=None, converter=optional_c(to_snowflake)) + message_id: int = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the originating message.""" - channel_id: Optional[int] = field(default=None, converter=optional_c(to_snowflake)) + channel_id: Optional[int] = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the originating message's channel.""" - guild_id: Optional[int] = field(default=None, converter=optional_c(to_snowflake)) + guild_id: Optional[int] = attrs.field(repr=False, default=None, converter=optional_c(to_snowflake)) """id of the originating message's guild.""" - fail_if_not_exists: bool = field(default=True) + fail_if_not_exists: bool = attrs.field(repr=False, default=True) """When sending a message, whether to error if the referenced message doesn't exist instead of sending as a normal (non-reply) message, default true.""" @classmethod @@ -130,14 +145,18 @@ def for_message(cls, message: "Message", fail_if_not_exists: bool = True) -> "Me ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class MessageInteraction(DiscordObject): - type: InteractionTypes = field(converter=InteractionTypes) + type: InteractionTypes = attrs.field(repr=False, converter=InteractionTypes) """the type of interaction""" - name: str = field() + name: str = attrs.field( + repr=False, + ) """the name of the application command""" - _user_id: "Snowflake_Type" = field() + _user_id: "Snowflake_Type" = attrs.field( + repr=False, + ) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -150,7 +169,7 @@ async def user(self) -> "models.User": return await self.get_user(self._user_id) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class AllowedMentions(DictSerializationMixin): """ The allowed mention field allows for more granular control over mentions without various hacks to the message content. @@ -160,13 +179,13 @@ class AllowedMentions(DictSerializationMixin): """ - parse: Optional[List[str]] = field(factory=list) + parse: Optional[List[str]] = attrs.field(repr=False, factory=list) """An array of allowed mention types to parse from the content.""" - roles: Optional[List["Snowflake_Type"]] = field(factory=list, converter=to_snowflake_list) + roles: Optional[List["Snowflake_Type"]] = attrs.field(repr=False, factory=list, converter=to_snowflake_list) """Array of role_ids to mention. (Max size of 100)""" - users: Optional[List["Snowflake_Type"]] = field(factory=list, converter=to_snowflake_list) + users: Optional[List["Snowflake_Type"]] = attrs.field(repr=False, factory=list, converter=to_snowflake_list) """Array of user_ids to mention. (Max size of 100)""" - replied_user = field(default=False) + replied_user = attrs.field(repr=False, default=False) """For replies, whether to mention the author of the message being replied to. (default false)""" def add_parse(self, *mention_types: Union["MentionTypes", str]) -> None: @@ -227,12 +246,14 @@ def none(cls) -> "AllowedMentions": return cls() -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseMessage(DiscordObject): - _channel_id: "Snowflake_Type" = field(default=MISSING, converter=to_optional_snowflake) - _thread_channel_id: Optional["Snowflake_Type"] = field(default=None, converter=to_optional_snowflake) - _guild_id: Optional["Snowflake_Type"] = field(default=None, converter=to_optional_snowflake) - _author_id: "Snowflake_Type" = field(default=MISSING, converter=to_optional_snowflake) + _channel_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING, converter=to_optional_snowflake) + _thread_channel_id: Optional["Snowflake_Type"] = attrs.field( + repr=False, default=None, converter=to_optional_snowflake + ) + _guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=to_optional_snowflake) + _author_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING, converter=to_optional_snowflake) @property def guild(self) -> "models.Guild": @@ -267,55 +288,57 @@ def author(self) -> Union["models.Member", "models.User"]: return MISSING -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Message(BaseMessage): - content: str = field(default=MISSING) + content: str = attrs.field(repr=False, default=MISSING) """Contents of the message""" - timestamp: "models.Timestamp" = field(default=MISSING, converter=optional_c(timestamp_converter)) + timestamp: "models.Timestamp" = attrs.field(repr=False, default=MISSING, converter=optional_c(timestamp_converter)) """When this message was sent""" - edited_timestamp: Optional["models.Timestamp"] = field(default=None, converter=optional_c(timestamp_converter)) + edited_timestamp: Optional["models.Timestamp"] = attrs.field( + repr=False, default=None, converter=optional_c(timestamp_converter) + ) """When this message was edited (or `None` if never)""" - tts: bool = field(default=False) + tts: bool = attrs.field(repr=False, default=False) """Whether this was a TTS message""" - mention_everyone: bool = field(default=False) + mention_everyone: bool = attrs.field(repr=False, default=False) """Whether this message mentions everyone""" - mention_channels: List[ChannelMention] = field(factory=list) + mention_channels: List[ChannelMention] = attrs.field(repr=False, factory=list) """Channels specifically mentioned in this message""" - attachments: List[Attachment] = field(factory=list) + attachments: List[Attachment] = attrs.field(repr=False, factory=list) """Any attached files""" - embeds: List["models.Embed"] = field(factory=list) + embeds: List["models.Embed"] = attrs.field(repr=False, factory=list) """Any embedded content""" - reactions: List["models.Reaction"] = field(factory=list) + reactions: List["models.Reaction"] = attrs.field(repr=False, factory=list) """Reactions to the message""" - nonce: Optional[Union[int, str]] = field(default=None) + nonce: Optional[Union[int, str]] = attrs.field(repr=False, default=None) """Used for validating a message was sent""" - pinned: bool = field(default=False) + pinned: bool = attrs.field(repr=False, default=False) """Whether this message is pinned""" - webhook_id: Optional["Snowflake_Type"] = field(default=None, converter=to_optional_snowflake) + webhook_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=to_optional_snowflake) """If the message is generated by a webhook, this is the webhook's id""" - type: MessageTypes = field(default=MISSING, converter=optional_c(MessageTypes)) + type: MessageTypes = attrs.field(repr=False, default=MISSING, converter=optional_c(MessageTypes)) """Type of message""" - activity: Optional[MessageActivity] = field(default=None, converter=optional_c(MessageActivity)) + activity: Optional[MessageActivity] = attrs.field(repr=False, default=None, converter=optional_c(MessageActivity)) """Activity sent with Rich Presence-related chat embeds""" - application: Optional["models.Application"] = field(default=None) # TODO: partial application + application: Optional["models.Application"] = attrs.field(repr=False, default=None) # TODO: partial application """Application sent with Rich Presence-related chat embeds""" - application_id: Optional["Snowflake_Type"] = field(default=None, converter=to_optional_snowflake) + application_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=to_optional_snowflake) """If the message is an Interaction or application-owned webhook, this is the id of the application""" - message_reference: Optional[MessageReference] = field( - default=None, converter=optional_c(MessageReference.from_dict) + message_reference: Optional[MessageReference] = attrs.field( + repr=False, default=None, converter=optional_c(MessageReference.from_dict) ) """Data showing the source of a crosspost, channel follow add, pin, or reply message""" - flags: MessageFlags = field(default=MessageFlags.NONE, converter=MessageFlags) + flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags) """Message flags combined as a bitfield""" - interaction: Optional["MessageInteraction"] = field(default=None) + interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None) """Sent if the message is a response to an Interaction""" - components: Optional[List["models.ActionRow"]] = field(default=None) + components: Optional[List["models.ActionRow"]] = attrs.field(repr=False, default=None) """Sent if the message contains components like buttons, action rows, or other interactive components""" - sticker_items: Optional[List["models.StickerItem"]] = field(default=None) + sticker_items: Optional[List["models.StickerItem"]] = attrs.field(repr=False, default=None) """Sent if the message contains stickers""" - _mention_ids: List["Snowflake_Type"] = field(factory=list) - _mention_roles: List["Snowflake_Type"] = field(factory=list) - _referenced_message_id: Optional["Snowflake_Type"] = field(default=None) + _mention_ids: List["Snowflake_Type"] = attrs.field(repr=False, factory=list) + _mention_roles: List["Snowflake_Type"] = attrs.field(repr=False, factory=list) + _referenced_message_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) @property async def mention_users(self) -> AsyncGenerator["models.Member", None]: @@ -358,6 +381,24 @@ def get_referenced_message(self) -> Optional["Message"]: return None return self._client.cache.get_message(self._channel_id, self._referenced_message_id) + def contains_mention( + self, + query: "str | re.Pattern[str] | models.BaseUser | models.BaseChannel | models.Role", + *, + tag_as_mention: bool = False, + ) -> bool: + """ + Check whether the message contains the query or not. + + Args: + query: The query to search for + tag_as_mention: Should `BaseUser.tag` be checked *(only if query is an instance of BaseUser)* + + Returns: + A boolean indicating whether the query could be found or not + """ + return mentions(text=self.content or self.system_content, query=query, tag_as_mention=tag_as_mention) + @classmethod def _process_dict(cls, data: dict, client: "Client") -> dict: if author_data := data.pop("author", None): @@ -503,6 +544,7 @@ def proto_url(self) -> str: async def edit( self, + *, content: Optional[str] = None, embeds: Optional[Union[Sequence[Union["models.Embed", dict]], Union["models.Embed", dict]]] = None, embed: Optional[Union["models.Embed", dict]] = None, @@ -520,6 +562,7 @@ async def edit( file: Optional[UPLOADABLE_TYPE] = None, tts: bool = False, flags: Optional[Union[int, MessageFlags]] = None, + context: "InteractionContext | None" = None, ) -> "Message": """ Edits the message. @@ -535,49 +578,72 @@ async def edit( file: Files to send, the path, bytes or File() instance, defaults to None. You may have up to 10 files. tts: Should this message use Text To Speech. flags: Message flags to apply. + context: The interaction context to use for the edit Returns: New message object with edits applied """ - message_payload = process_message_payload( - content=content, - embeds=embeds or embed, - components=components, - allowed_mentions=allowed_mentions, - attachments=attachments, - tts=tts, - flags=flags, - ) + if context: + return await context.edit( + self, + content=content, + embeds=embeds, + embed=embed, + components=components, + allowed_mentions=allowed_mentions, + attachments=attachments, + files=files, + file=file, + tts=tts, + ) + else: + if self.flags == MessageFlags.EPHEMERAL: + raise EphemeralEditException + message_payload = process_message_payload( + content=content, + embeds=embeds or embed, + components=components, + allowed_mentions=allowed_mentions, + attachments=attachments, + tts=tts, + flags=flags, + ) + if file: + if files: + files = [file, *files] + else: + files = [file] - if self.flags == MessageFlags.EPHEMERAL: - raise EphemeralEditException + message_data = await self._client.http.edit_message(message_payload, self._channel_id, self.id, files=files) + if message_data: + return self._client.cache.place_message_data(message_data) - message_data = await self._client.http.edit_message(message_payload, self._channel_id, self.id, files=files) - if message_data: - return self._client.cache.place_message_data(message_data) - - async def delete(self, delay: Absent[Optional[int]] = MISSING) -> None: + async def delete(self, delay: int = 0, *, context: "InteractionContext | None" = None) -> None: """ Delete message. Args: delay: Seconds to wait before deleting message. + context: An optional interaction context to delete ephemeral messages. """ - if delay and delay > 0: - async def delayed_delete() -> None: + async def _delete() -> None: + if delay: await asyncio.sleep(delay) - try: - await self._client.http.delete_message(self._channel_id, self.id) - except Exception: # noqa: S110 - pass # No real way to handle this - asyncio.create_task(delayed_delete()) + if MessageFlags.EPHEMERAL in self.flags: + if not context: + raise ValueError("Cannot delete ephemeral message without interaction context parameter") + await context.delete(self.id) + else: + await self._client.http.delete_message(self._channel_id, self.id) + if delay: + asyncio.create_task(_delete()) else: - await self._client.http.delete_message(self._channel_id, self.id) + return await _delete() async def reply( self, diff --git a/naff/models/discord/modal.py b/naff/models/discord/modal.py index 290a71ad2..e7a1f770b 100644 --- a/naff/models/discord/modal.py +++ b/naff/models/discord/modal.py @@ -6,9 +6,9 @@ from naff.client.const import MISSING from naff.client.mixins.serialization import DictSerializationMixin -from naff.models.naff.application_commands import CallbackTypes +from naff.client.utils.attr_utils import str_validator from naff.models.discord.components import InteractiveComponent, ComponentTypes -from naff.client.utils.attr_utils import define, field, str_validator +from naff.models.naff.application_commands import CallbackTypes __all__ = ("InputText", "Modal", "ParagraphText", "ShortText", "TextStyles") @@ -18,59 +18,65 @@ class TextStyles(IntEnum): PARAGRAPH = 2 -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class InputText(InteractiveComponent): """An input component for modals""" - type: Union[ComponentTypes, int] = field( - default=ComponentTypes.INPUT_TEXT, init=False, on_setattr=attrs.setters.frozen + type: Union[ComponentTypes, int] = attrs.field( + repr=False, default=ComponentTypes.INPUT_TEXT, init=False, on_setattr=attrs.setters.frozen ) - label: str = field(validator=str_validator) + label: str = attrs.field(repr=False, validator=str_validator) """the label for this component""" - style: Union[TextStyles, int] = field() + style: Union[TextStyles, int] = attrs.field( + repr=False, + ) """the Text Input Style for single or multiple lines input""" - custom_id: Optional[str] = field(factory=lambda: str(uuid.uuid4()), validator=str_validator) + custom_id: Optional[str] = attrs.field(repr=False, factory=lambda: str(uuid.uuid4()), validator=str_validator) """a developer-defined identifier for the input, max 100 characters""" - placeholder: Optional[str] = field(default=MISSING, validator=str_validator, kw_only=True) + placeholder: Optional[str] = attrs.field(repr=False, default=MISSING, validator=str_validator, kw_only=True) """custom placeholder text if the input is empty, max 100 characters""" - value: Optional[str] = field(default=MISSING, validator=str_validator, kw_only=True) + value: Optional[str] = attrs.field(repr=False, default=MISSING, validator=str_validator, kw_only=True) """a pre-filled value for this component, max 4000 characters""" - required: bool = field(default=True, kw_only=True) + required: bool = attrs.field(repr=False, default=True, kw_only=True) """whether this component is required to be filled, default true""" - min_length: Optional[int] = field(default=MISSING, kw_only=True) + min_length: Optional[int] = attrs.field(repr=False, default=MISSING, kw_only=True) """the minimum input length for a text input, min 0, max 4000""" - max_length: Optional[int] = field(default=MISSING, kw_only=True) + max_length: Optional[int] = attrs.field(repr=False, default=MISSING, kw_only=True) """the maximum input length for a text input, min 1, max 4000. Must be more than min_length.""" -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ShortText(InputText): """A single line input component for modals""" - style: Union[TextStyles, int] = field(default=TextStyles.SHORT, kw_only=True) + style: Union[TextStyles, int] = attrs.field(repr=False, default=TextStyles.SHORT, kw_only=True) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class ParagraphText(InputText): """A multi line input component for modals""" - style: Union[TextStyles, int] = field(default=TextStyles.PARAGRAPH, kw_only=True) + style: Union[TextStyles, int] = attrs.field(repr=False, default=TextStyles.PARAGRAPH, kw_only=True) -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class Modal(DictSerializationMixin): """Form submission style component on discord""" - type: Union[CallbackTypes, int] = field(default=CallbackTypes.MODAL, init=False, on_setattr=attrs.setters.frozen) - title: str = field(validator=str_validator) + type: Union[CallbackTypes, int] = attrs.field( + repr=False, default=CallbackTypes.MODAL, init=False, on_setattr=attrs.setters.frozen + ) + title: str = attrs.field(repr=False, validator=str_validator) """the title of the popup modal, max 45 characters""" - components: List[InputText] = field() + components: List[InputText] = attrs.field( + repr=False, + ) """between 1 and 5 (inclusive) components that make up the modal""" - custom_id: Optional[str] = field(factory=lambda: str(uuid.uuid4()), validator=str_validator) + custom_id: Optional[str] = attrs.field(repr=False, factory=lambda: str(uuid.uuid4()), validator=str_validator) """a developer-defined identifier for the component, max 100 characters""" def __attrs_post_init__(self) -> None: diff --git a/naff/models/discord/reaction.py b/naff/models/discord/reaction.py index 98e3b1639..d3e69c4ed 100644 --- a/naff/models/discord/reaction.py +++ b/naff/models/discord/reaction.py @@ -2,8 +2,9 @@ from collections import namedtuple from typing import TYPE_CHECKING, List, Optional +import attrs + from naff.client.const import MISSING -from naff.client.utils.attr_utils import define, field from naff.models.discord.emoji import PartialEmoji from naff.models.discord.snowflake import to_snowflake from naff.models.misc.iterator import AsyncIterator @@ -64,17 +65,19 @@ async def fetch(self) -> List["User"]: raise QueueEmpty -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Reaction(ClientObject): - count: int = field() + count: int = attrs.field( + repr=False, + ) """times this emoji has been used to react""" - me: bool = field(default=False) + me: bool = attrs.field(repr=False, default=False) """whether the current user reacted using this emoji""" - emoji: "PartialEmoji" = field(converter=PartialEmoji.from_dict) + emoji: "PartialEmoji" = attrs.field(repr=False, converter=PartialEmoji.from_dict) """emoji information""" - _channel_id: "Snowflake_Type" = field(converter=to_snowflake) - _message_id: "Snowflake_Type" = field(converter=to_snowflake) + _channel_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) + _message_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) def users(self, limit: int = 0, after: "Snowflake_Type" = None) -> ReactionUsers: """Users who reacted using this emoji.""" diff --git a/naff/models/discord/role.py b/naff/models/discord/role.py index 2781db3cf..8f9458a64 100644 --- a/naff/models/discord/role.py +++ b/naff/models/discord/role.py @@ -4,12 +4,11 @@ import attrs from naff.client.const import MISSING, T, Missing -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.serializer import dict_filter from naff.models.discord.asset import Asset -from naff.models.discord.emoji import PartialEmoji from naff.models.discord.color import COLOR_TYPES, Color, process_color +from naff.models.discord.emoji import PartialEmoji from naff.models.discord.enums import Permissions from .base import DiscordObject @@ -30,24 +29,30 @@ def sentinel_converter(value: bool | T | None, sentinel: T = attrs.NOTHING) -> b return value -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) @total_ordering class Role(DiscordObject): _sentinel = object() - name: str = field(repr=True) - color: "Color" = field(converter=Color) - hoist: bool = field(default=False) - position: int = field(repr=True) - permissions: "Permissions" = field(converter=Permissions) - managed: bool = field(default=False) - mentionable: bool = field(default=True) - premium_subscriber: bool = field(default=_sentinel, converter=partial(sentinel_converter, sentinel=_sentinel)) - _icon: Asset | None = field(default=None) - _unicode_emoji: PartialEmoji | None = field(default=None, converter=optional_c(PartialEmoji.from_str)) - _guild_id: "Snowflake_Type" = field() - _bot_id: "Snowflake_Type | None" = field(default=None) - _integration_id: "Snowflake_Type | None" = field(default=None) # todo integration object? + name: str = attrs.field(repr=True) + color: "Color" = attrs.field(repr=False, converter=Color) + hoist: bool = attrs.field(repr=False, default=False) + position: int = attrs.field(repr=True) + permissions: "Permissions" = attrs.field(repr=False, converter=Permissions) + managed: bool = attrs.field(repr=False, default=False) + mentionable: bool = attrs.field(repr=False, default=True) + premium_subscriber: bool = attrs.field( + repr=False, default=_sentinel, converter=partial(sentinel_converter, sentinel=_sentinel) + ) + _icon: Asset | None = attrs.field(repr=False, default=None) + _unicode_emoji: PartialEmoji | None = attrs.field( + repr=False, default=None, converter=optional_c(PartialEmoji.from_str) + ) + _guild_id: "Snowflake_Type" = attrs.field( + repr=False, + ) + _bot_id: "Snowflake_Type | None" = attrs.field(repr=False, default=None) + _integration_id: "Snowflake_Type | None" = attrs.field(repr=False, default=None) # todo integration object? def __lt__(self: "Role", other: "Role") -> bool: if not isinstance(self, Role) or not isinstance(other, Role): @@ -167,6 +172,7 @@ async def delete(self, reason: str | Missing = MISSING) -> None: async def edit( self, + *, name: str | None = None, permissions: str | None = None, color: Color | COLOR_TYPES | None = None, diff --git a/naff/models/discord/scheduled_event.py b/naff/models/discord/scheduled_event.py index 9fec8b7e2..cdcb1a23e 100644 --- a/naff/models/discord/scheduled_event.py +++ b/naff/models/discord/scheduled_event.py @@ -1,17 +1,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import attrs + from naff.client.const import MISSING, Absent from naff.client.errors import EventLocationNotProvided -from naff.client.utils.attr_utils import define, field +from naff.client.utils import to_image_data from naff.client.utils.attr_converters import optional from naff.client.utils.attr_converters import timestamp_converter +from naff.models.discord.asset import Asset +from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.discord.snowflake import Snowflake_Type, to_snowflake from naff.models.discord.timestamp import Timestamp from .base import DiscordObject from .enums import ScheduledEventPrivacyLevel, ScheduledEventType, ScheduledEventStatus -from naff.models.discord.asset import Asset -from naff.models.discord.file import UPLOADABLE_TYPE -from naff.client.utils import to_image_data if TYPE_CHECKING: from naff.client import Client @@ -23,38 +24,40 @@ __all__ = ("ScheduledEvent",) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ScheduledEvent(DiscordObject): - name: str = field(repr=True) - description: str = field(default=MISSING) - entity_type: Union[ScheduledEventType, int] = field(converter=ScheduledEventType) + name: str = attrs.field(repr=True) + description: str = attrs.field(repr=False, default=MISSING) + entity_type: Union[ScheduledEventType, int] = attrs.field(repr=False, converter=ScheduledEventType) """The type of the scheduled event""" - start_time: Timestamp = field(converter=timestamp_converter) + start_time: Timestamp = attrs.field(repr=False, converter=timestamp_converter) """A Timestamp object representing the scheduled start time of the event """ - end_time: Optional[Timestamp] = field(default=None, converter=optional(timestamp_converter)) + end_time: Optional[Timestamp] = attrs.field(repr=False, default=None, converter=optional(timestamp_converter)) """Optional Timstamp object representing the scheduled end time, required if entity_type is EXTERNAL""" - privacy_level: Union[ScheduledEventPrivacyLevel, int] = field(converter=ScheduledEventPrivacyLevel) + privacy_level: Union[ScheduledEventPrivacyLevel, int] = attrs.field( + repr=False, converter=ScheduledEventPrivacyLevel + ) """ Privacy level of the scheduled event ??? note Discord only has `GUILD_ONLY` at the momment. """ - status: Union[ScheduledEventStatus, int] = field(converter=ScheduledEventStatus) + status: Union[ScheduledEventStatus, int] = attrs.field(repr=False, converter=ScheduledEventStatus) """Current status of the scheduled event""" - entity_id: Optional["Snowflake_Type"] = field(default=MISSING, converter=optional(to_snowflake)) + entity_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=MISSING, converter=optional(to_snowflake)) """The id of an entity associated with a guild scheduled event""" - entity_metadata: Optional[Dict[str, Any]] = field(default=MISSING) # TODO make this + entity_metadata: Optional[Dict[str, Any]] = attrs.field(repr=False, default=MISSING) # TODO make this """The metadata associated with the entity_type""" - user_count: int = field(default=MISSING) + user_count: int = attrs.field(repr=False, default=MISSING) """Amount of users subscribed to the scheduled event""" - cover: Asset | None = field(default=None) + cover: Asset | None = attrs.field(repr=False, default=None) """The cover image of this event""" - _guild_id: "Snowflake_Type" = field(converter=to_snowflake) - _creator: Optional["User"] = field(default=MISSING) - _creator_id: Optional["Snowflake_Type"] = field(default=MISSING, converter=optional(to_snowflake)) - _channel_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) + _guild_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) + _creator: Optional["User"] = attrs.field(repr=False, default=MISSING) + _creator_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=MISSING, converter=optional(to_snowflake)) + _channel_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) @property async def creator(self) -> Optional["User"]: @@ -157,6 +160,7 @@ async def delete(self, reason: Absent[str] = MISSING) -> None: async def edit( self, + *, name: Absent[str] = MISSING, start_time: Absent["Timestamp"] = MISSING, end_time: Absent["Timestamp"] = MISSING, diff --git a/naff/models/discord/snowflake.py b/naff/models/discord/snowflake.py index 6b9bc6c3e..fd0a3beb1 100644 --- a/naff/models/discord/snowflake.py +++ b/naff/models/discord/snowflake.py @@ -1,8 +1,9 @@ from typing import Union, List, SupportsInt, Optional +import attrs + import naff.models as models from naff.client.const import MISSING, Absent -from naff.client.utils.attr_utils import define, field __all__ = ("to_snowflake", "to_optional_snowflake", "to_snowflake_list", "SnowflakeObject", "Snowflake_Type") @@ -52,9 +53,9 @@ def to_snowflake_list(snowflakes: List[Snowflake_Type]) -> List[int]: return [to_snowflake(c) for c in snowflakes] -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class SnowflakeObject: - id: int = field(repr=True, converter=to_snowflake, metadata={"docs": "Discord unique snowflake ID"}) + id: int = attrs.field(repr=True, converter=to_snowflake, metadata={"docs": "Discord unique snowflake ID"}) def __eq__(self, other: "SnowflakeObject") -> bool: if hasattr(other, "id"): diff --git a/naff/models/discord/stage_instance.py b/naff/models/discord/stage_instance.py index 696349da8..296b11844 100644 --- a/naff/models/discord/stage_instance.py +++ b/naff/models/discord/stage_instance.py @@ -1,7 +1,8 @@ from typing import TYPE_CHECKING, Optional +import attrs + from naff.client.const import MISSING, Absent -from naff.client.utils.attr_utils import define, field from naff.models.discord.enums import StagePrivacyLevel from naff.models.discord.snowflake import to_snowflake from .base import DiscordObject @@ -12,14 +13,20 @@ __all__ = ("StageInstance",) -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class StageInstance(DiscordObject): - topic: str = field() - privacy_level: StagePrivacyLevel = field() - discoverable_disabled: bool = field() - - _guild_id: "Snowflake_Type" = field(converter=to_snowflake) - _channel_id: "Snowflake_Type" = field(converter=to_snowflake) + topic: str = attrs.field( + repr=False, + ) + privacy_level: StagePrivacyLevel = attrs.field( + repr=False, + ) + discoverable_disabled: bool = attrs.field( + repr=False, + ) + + _guild_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) + _channel_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) @property def guild(self) -> "Guild": diff --git a/naff/models/discord/sticker.py b/naff/models/discord/sticker.py index ebde3a325..30355e822 100644 --- a/naff/models/discord/sticker.py +++ b/naff/models/discord/sticker.py @@ -1,8 +1,9 @@ from enum import IntEnum from typing import TYPE_CHECKING, List, Optional, Union +import attrs + from naff.client.const import MISSING, Absent -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional from naff.client.utils.serializer import dict_filter_none from naff.models.discord.snowflake import to_snowflake @@ -33,33 +34,33 @@ class StickerFormatTypes(IntEnum): LOTTIE = 3 -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class StickerItem(DiscordObject): - name: str = field(repr=True) + name: str = attrs.field(repr=True) """Name of the sticker.""" - format_type: StickerFormatTypes = field(repr=True, converter=StickerFormatTypes) + format_type: StickerFormatTypes = attrs.field(repr=True, converter=StickerFormatTypes) """Type of sticker image format.""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Sticker(StickerItem): """Represents a sticker that can be sent in messages.""" - pack_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) + pack_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) """For standard stickers, id of the pack the sticker is from.""" - description: Optional[str] = field(default=None) + description: Optional[str] = attrs.field(repr=False, default=None) """Description of the sticker.""" - tags: str = field() + tags: str = attrs.field(repr=False) """autocomplete/suggestion tags for the sticker (max 200 characters)""" - type: Union[StickerTypes, int] = field(converter=StickerTypes) + type: Union[StickerTypes, int] = attrs.field(repr=False, converter=StickerTypes) """Type of sticker.""" - available: Optional[bool] = field(default=True) + available: Optional[bool] = attrs.field(repr=False, default=True) """Whether this guild sticker can be used, may be false due to loss of Server Boosts.""" - sort_value: Optional[int] = field(default=None) + sort_value: Optional[int] = attrs.field(repr=False, default=None) """The standard sticker's sort order within its pack.""" - _user_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) - _guild_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) + _user_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) + _guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) async def fetch_creator(self) -> "User": """ @@ -103,6 +104,7 @@ def get_guild(self) -> "Guild": async def edit( self, + *, name: Absent[Optional[str]] = MISSING, description: Absent[Optional[str]] = MISSING, tags: Absent[Optional[str]] = MISSING, @@ -145,19 +147,19 @@ async def delete(self, reason: Optional[str] = MISSING) -> None: await self._client.http.delete_guild_sticker(self._guild_id, self.id, reason) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class StickerPack(DiscordObject): """Represents a pack of standard stickers.""" - stickers: List["Sticker"] = field(factory=list) + stickers: List["Sticker"] = attrs.field(repr=False, factory=list) """The stickers in the pack.""" - name: str = field(repr=True) + name: str = attrs.field(repr=True) """Name of the sticker pack.""" - sku_id: "Snowflake_Type" = field(repr=True) + sku_id: "Snowflake_Type" = attrs.field(repr=True) """id of the pack's SKU.""" - cover_sticker_id: Optional["Snowflake_Type"] = field(default=None) + cover_sticker_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """id of a sticker in the pack which is shown as the pack's icon.""" - description: str = field() + description: str = attrs.field(repr=False) """Description of the sticker pack.""" - banner_asset_id: "Snowflake_Type" = field() # TODO CDN Asset + banner_asset_id: "Snowflake_Type" = attrs.field(repr=False) # TODO CDN Asset """id of the sticker pack's banner image.""" diff --git a/naff/models/discord/team.py b/naff/models/discord/team.py index 62ae9a57e..daf50a74f 100644 --- a/naff/models/discord/team.py +++ b/naff/models/discord/team.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, List, Optional, Dict, Any, Union -from naff.client.utils.attr_utils import define, field +import attrs + from naff.models.discord.asset import Asset from naff.models.discord.enums import TeamMembershipState from naff.models.discord.snowflake import to_snowflake @@ -14,14 +15,16 @@ __all__ = ("TeamMember", "Team") -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class TeamMember(DiscordObject): - membership_state: TeamMembershipState = field(converter=TeamMembershipState) + membership_state: TeamMembershipState = attrs.field(repr=False, converter=TeamMembershipState) """Rhe user's membership state on the team""" - # permissions: List[str] = field(default=["*"]) # disabled until discord adds more team roles - team_id: "Snowflake_Type" = field(repr=True) + # permissions: List[str] = attrs.field(repr=False, default=["*"]) # disabled until discord adds more team roles + team_id: "Snowflake_Type" = attrs.field(repr=True) """Rhe id of the parent team of which they are a member""" - user: "User" = field() # TODO: cache partial user (avatar, discrim, id, username) + user: "User" = attrs.field( + repr=False, + ) # TODO: cache partial user (avatar, discrim, id, username) """Rhe avatar, discriminator, id, and username of the user""" @classmethod @@ -31,15 +34,15 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Team(DiscordObject): - icon: Optional[Asset] = field(default=None) + icon: Optional[Asset] = attrs.field(repr=False, default=None) """A hash of the image of the team's icon""" - members: List[TeamMember] = field(factory=list) + members: List[TeamMember] = attrs.field(repr=False, factory=list) """The members of the team""" - name: str = field(repr=True) + name: str = attrs.field(repr=True) """The name of the team""" - owner_user_id: "Snowflake_Type" = field(converter=to_snowflake) + owner_user_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) """The user id of the current team owner""" @classmethod diff --git a/naff/models/discord/thread.py b/naff/models/discord/thread.py index 33eaf0a7c..e56332c9a 100644 --- a/naff/models/discord/thread.py +++ b/naff/models/discord/thread.py @@ -1,11 +1,12 @@ from typing import TYPE_CHECKING, List, Dict, Any, Union, Optional +import attrs + import naff.models as models from naff.client.const import MISSING from naff.client.mixins.send import SendMixin from naff.client.utils.attr_converters import optional from naff.client.utils.attr_converters import timestamp_converter -from naff.client.utils.attr_utils import define, field from naff.models.discord.emoji import PartialEmoji from naff.models.discord.snowflake import to_snowflake from naff.models.discord.timestamp import Timestamp @@ -27,16 +28,18 @@ ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ThreadMember(DiscordObject, SendMixin): """A thread member is used to indicate whether a user has joined a thread or not.""" - join_timestamp: Timestamp = field(converter=timestamp_converter) + join_timestamp: Timestamp = attrs.field(repr=False, converter=timestamp_converter) """The time the current user last joined the thread.""" - flags: int = field() + flags: int = attrs.field( + repr=False, + ) """Any user-thread settings, currently only used for notifications.""" - _user_id: "Snowflake_Type" = field(converter=optional(to_snowflake)) + _user_id: "Snowflake_Type" = attrs.field(repr=False, converter=optional(to_snowflake)) async def fetch_thread(self) -> "TYPE_THREAD_CHANNEL": """ @@ -85,15 +88,17 @@ async def _send_http_request( return await self._client.http.create_message(message_payload, dm_id, files=files) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ThreadList(ClientObject): """Represents a list of one or more threads.""" - threads: List["TYPE_THREAD_CHANNEL"] = field(factory=list) # TODO Reference the cache or store actual object? + threads: List["TYPE_THREAD_CHANNEL"] = attrs.field( + repr=False, factory=list + ) # TODO Reference the cache or store actual object? """The active threads.""" - members: List[ThreadMember] = field(factory=list) + members: List[ThreadMember] = attrs.field(repr=False, factory=list) """A thread member object for each returned thread the current user has joined.""" - has_more: bool = field(default=False) + has_more: bool = attrs.field(repr=False, default=False) """Whether there are potentially additional threads that could be returned on a subsequent call.""" @classmethod @@ -108,13 +113,15 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ThreadTag(DiscordObject): - name: str = field() - emoji_id: "Snowflake_Type" = field(default=None) - emoji_name: str | None = field(default=None) + name: str = attrs.field( + repr=False, + ) + emoji_id: "Snowflake_Type" = attrs.field(repr=False, default=None) + emoji_name: str | None = attrs.field(repr=False, default=None) - _parent_channel_id: "Snowflake_Type" = field(default=MISSING) + _parent_channel_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING) @property def parent_channel(self) -> "GuildForum": diff --git a/naff/models/discord/timestamp.py b/naff/models/discord/timestamp.py index 6230e07ed..359b120f5 100644 --- a/naff/models/discord/timestamp.py +++ b/naff/models/discord/timestamp.py @@ -20,6 +20,9 @@ class TimestampStyles(str, Enum): LongDateTime = "F" RelativeTime = "R" + def __str__(self) -> str: + return self.value + class Timestamp(datetime): """ diff --git a/naff/models/discord/user.py b/naff/models/discord/user.py index f92d33f46..a308376a1 100644 --- a/naff/models/discord/user.py +++ b/naff/models/discord/user.py @@ -2,13 +2,15 @@ from typing import TYPE_CHECKING, Any, Iterable, Set, Dict, List, Optional, Union from warnings import warn -from naff.client.const import MISSING, logger, Absent +import attrs + +from naff.client.const import Absent, MISSING from naff.client.errors import HTTPException, TooManyChanges from naff.client.mixins.send import SendMixin -from naff.client.utils.attr_utils import define, field, docs from naff.client.utils.attr_converters import list_converter, optional from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter +from naff.client.utils.attr_utils import docs from naff.client.utils.serializer import to_image_data from naff.models.discord.activity import Activity from naff.models.discord.asset import Asset @@ -41,13 +43,13 @@ async def _send_http_request( return await self._client.http.create_message(message_payload, dm_id, files=files) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseUser(DiscordObject, _SendDMMixin): """Base class for User, essentially partial user discord model.""" - username: str = field(repr=True, metadata=docs("The user's username, not unique across the platform")) - discriminator: int = field(repr=True, metadata=docs("The user's 4-digit discord-tag")) - avatar: "Asset" = field(metadata=docs("The user's default avatar")) + username: str = attrs.field(repr=True, metadata=docs("The user's username, not unique across the platform")) + discriminator: int = attrs.field(repr=True, metadata=docs("The user's 4-digit discord-tag")) + avatar: "Asset" = attrs.field(repr=False, metadata=docs("The user's default avatar")) def __str__(self) -> str: return self.tag @@ -102,32 +104,37 @@ def mutual_guilds(self) -> List["Guild"]: ] -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class User(BaseUser): - bot: bool = field(repr=True, default=False, metadata=docs("Is this user a bot?")) - system: bool = field( + bot: bool = attrs.field(repr=True, default=False, metadata=docs("Is this user a bot?")) + system: bool = attrs.field( default=False, metadata=docs("whether the user is an Official Discord System user (part of the urgent message system)"), ) - public_flags: "UserFlags" = field( + public_flags: "UserFlags" = attrs.field( repr=True, default=0, converter=UserFlags, metadata=docs("The flags associated with this user") ) - premium_type: "PremiumTypes" = field( - default=0, converter=PremiumTypes, metadata=docs("The type of nitro subscription on a user's account") + premium_type: "PremiumTypes" = attrs.field( + repr=False, + default=0, + converter=PremiumTypes, + metadata=docs("The type of nitro subscription on a user's account"), ) - banner: Optional["Asset"] = field(default=None, metadata=docs("The user's banner")) - accent_color: Optional["Color"] = field( + banner: Optional["Asset"] = attrs.field(repr=False, default=None, metadata=docs("The user's banner")) + accent_color: Optional["Color"] = attrs.field( default=None, converter=optional_c(Color), metadata=docs("The user's banner color"), ) - activities: list[Activity] = field( + activities: list[Activity] = attrs.field( factory=list, converter=list_converter(optional(Activity.from_dict)), metadata=docs("A list of activities the user is in"), ) - status: Absent[Status] = field(default=MISSING, metadata=docs("The user's status"), converter=optional(Status)) + status: Absent[Status] = attrs.field( + repr=False, default=MISSING, metadata=docs("The user's status"), converter=optional(Status) + ) @classmethod def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: @@ -154,18 +161,26 @@ def member_instances(self) -> List["Member"]: return [member for member in member_objs if member] -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class NaffUser(User): - verified: bool = field(repr=True, metadata={"docs": "Whether the email on this account has been verified"}) - mfa_enabled: bool = field( - default=False, metadata={"docs": "Whether the user has two factor enabled on their account"} + verified: bool = attrs.field(repr=True, metadata={"docs": "Whether the email on this account has been verified"}) + mfa_enabled: bool = attrs.field( + repr=False, default=False, metadata={"docs": "Whether the user has two factor enabled on their account"} + ) + email: Optional[str] = attrs.field( + repr=False, default=None, metadata={"docs": "the user's email"} + ) # needs special permissions? + locale: Optional[str] = attrs.field( + repr=False, default=None, metadata={"docs": "the user's chosen language option"} + ) + bio: Optional[str] = attrs.field(repr=False, default=None, metadata={"docs": ""}) + flags: "UserFlags" = attrs.field( + repr=False, default=0, converter=UserFlags, metadata={"docs": "the flags on a user's account"} ) - email: Optional[str] = field(default=None, metadata={"docs": "the user's email"}) # needs special permissions? - locale: Optional[str] = field(default=None, metadata={"docs": "the user's chosen language option"}) - bio: Optional[str] = field(default=None, metadata={"docs": ""}) - flags: "UserFlags" = field(default=0, converter=UserFlags, metadata={"docs": "the flags on a user's account"}) - _guild_ids: Set["Snowflake_Type"] = field(factory=set, metadata={"docs": "All the guilds the user is in"}) + _guild_ids: Set["Snowflake_Type"] = attrs.field( + repr=False, factory=set, metadata={"docs": "All the guilds the user is in"} + ) def _add_guilds(self, guild_ids: Set["Snowflake_Type"]) -> None: """ @@ -182,7 +197,7 @@ def guilds(self) -> List["Guild"]: """The guilds the user is in.""" return [self._client.cache.get_guild(g_id) for g_id in self._guild_ids] - async def edit(self, username: Absent[str] = MISSING, avatar: Absent[UPLOADABLE_TYPE] = MISSING) -> None: + async def edit(self, *, username: Absent[str] = MISSING, avatar: Absent[UPLOADABLE_TYPE] = MISSING) -> None: """ Edit the client's user. @@ -224,33 +239,38 @@ async def edit(self, username: Absent[str] = MISSING, avatar: Absent[UPLOADABLE_ self._client.cache.place_user_data(resp) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Member(DiscordObject, _SendDMMixin): - bot: bool = field(repr=True, default=False, metadata=docs("Is this user a bot?")) - nick: Optional[str] = field(repr=True, default=None, metadata=docs("The user's nickname in this guild'")) - deaf: bool = field(default=False, metadata=docs("Has this user been deafened in voice channels?")) - mute: bool = field(default=False, metadata=docs("Has this user been muted in voice channels?")) - joined_at: "Timestamp" = field( - default=MISSING, converter=optional(timestamp_converter), metadata=docs("When the user joined this guild") + bot: bool = attrs.field(repr=True, default=False, metadata=docs("Is this user a bot?")) + nick: Optional[str] = attrs.field(repr=True, default=None, metadata=docs("The user's nickname in this guild'")) + deaf: bool = attrs.field(repr=False, default=False, metadata=docs("Has this user been deafened in voice channels?")) + mute: bool = attrs.field(repr=False, default=False, metadata=docs("Has this user been muted in voice channels?")) + joined_at: "Timestamp" = attrs.field( + repr=False, + default=MISSING, + converter=optional(timestamp_converter), + metadata=docs("When the user joined this guild"), ) - premium_since: Optional["Timestamp"] = field( + premium_since: Optional["Timestamp"] = attrs.field( default=None, converter=optional_c(timestamp_converter), metadata=docs("When the user started boosting the guild"), ) - pending: Optional[bool] = field( - default=None, metadata=docs("Whether the user has **not** passed guild's membership screening requirements") + pending: Optional[bool] = attrs.field( + repr=False, + default=None, + metadata=docs("Whether the user has **not** passed guild's membership screening requirements"), ) - guild_avatar: "Asset" = field(default=None, metadata=docs("The user's guild avatar")) - communication_disabled_until: Optional["Timestamp"] = field( + guild_avatar: "Asset" = attrs.field(repr=False, default=None, metadata=docs("The user's guild avatar")) + communication_disabled_until: Optional["Timestamp"] = attrs.field( default=None, converter=optional_c(timestamp_converter), metadata=docs("When a member's timeout will expire, `None` or a time in the past if the user is not timed out"), ) - _guild_id: "Snowflake_Type" = field(repr=True, metadata=docs("The ID of the guild")) - _role_ids: List["Snowflake_Type"] = field( - factory=list, converter=list_converter(to_snowflake), metadata=docs("The roles IDs this user has") + _guild_id: "Snowflake_Type" = attrs.field(repr=True, metadata=docs("The ID of the guild")) + _role_ids: List["Snowflake_Type"] = attrs.field( + repr=False, factory=list, converter=list_converter(to_snowflake), metadata=docs("The roles IDs this user has") ) @classmethod @@ -274,7 +294,7 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] client, f"guilds/{data['guild_id']}/users/{data['id']}/avatars/{{}}", data.pop("avatar", None) ) except Exception as e: - logger.warning( + client.logger.warning( f"[DEBUG NEEDED - REPORT THIS] Incomplete dictionary has been passed to member object: {e}" ) raise diff --git a/naff/models/discord/user.pyi b/naff/models/discord/user.pyi index c41730d5b..d347fb958 100644 --- a/naff/models/discord/user.pyi +++ b/naff/models/discord/user.pyi @@ -1,5 +1,6 @@ -from .base import DiscordObject from datetime import datetime +from typing import Iterable, List, Optional, Union, Set + from naff.client.const import Absent from naff.client.mixins.send import SendMixin from naff.models.discord.activity import Activity @@ -13,7 +14,7 @@ from naff.models.discord.role import Role from naff.models.discord.snowflake import Snowflake_Type from naff.models.discord.timestamp import Timestamp from naff.models.discord.voice_state import VoiceState -from typing import Iterable, List, Optional, Union, Set +from .base import DiscordObject class _SendDMMixin(SendMixin): id: Snowflake_Type diff --git a/naff/models/discord/voice_state.py b/naff/models/discord/voice_state.py index d98f412d3..3f9db7990 100644 --- a/naff/models/discord/voice_state.py +++ b/naff/models/discord/voice_state.py @@ -1,11 +1,12 @@ import copy from typing import TYPE_CHECKING, Optional, Dict, Any +import attrs + from naff.client.const import MISSING -from naff.client.utils.attr_utils import define, field +from naff.client.mixins.serialization import DictSerializationMixin from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter -from naff.client.mixins.serialization import DictSerializationMixin from naff.models.discord.snowflake import to_snowflake from naff.models.discord.timestamp import Timestamp from .base import ClientObject @@ -19,33 +20,35 @@ __all__ = ("VoiceState", "VoiceRegion") -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class VoiceState(ClientObject): - user_id: "Snowflake_Type" = field(default=MISSING, converter=to_snowflake) + user_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING, converter=to_snowflake) """the user id this voice state is for""" - session_id: str = field(default=MISSING) + session_id: str = attrs.field(repr=False, default=MISSING) """the session id for this voice state""" - deaf: bool = field(default=False) + deaf: bool = attrs.field(repr=False, default=False) """whether this user is deafened by the server""" - mute: bool = field(default=False) + mute: bool = attrs.field(repr=False, default=False) """whether this user is muted by the server""" - self_deaf: bool = field(default=False) + self_deaf: bool = attrs.field(repr=False, default=False) """whether this user is locally deafened""" - self_mute: bool = field(default=False) + self_mute: bool = attrs.field(repr=False, default=False) """whether this user is locally muted""" - self_stream: Optional[bool] = field(default=False) + self_stream: Optional[bool] = attrs.field(repr=False, default=False) """whether this user is streaming using "Go Live\"""" - self_video: bool = field(default=False) + self_video: bool = attrs.field(repr=False, default=False) """whether this user's camera is enabled""" - suppress: bool = field(default=False) + suppress: bool = attrs.field(repr=False, default=False) """whether this user is muted by the current user""" - request_to_speak_timestamp: Optional[Timestamp] = field(default=None, converter=optional_c(timestamp_converter)) + request_to_speak_timestamp: Optional[Timestamp] = attrs.field( + repr=False, default=None, converter=optional_c(timestamp_converter) + ) """the time at which the user requested to speak""" # internal for props - _guild_id: Optional["Snowflake_Type"] = field(default=None, converter=to_snowflake) - _channel_id: "Snowflake_Type" = field(converter=to_snowflake) - _member_id: Optional["Snowflake_Type"] = field(default=None, converter=to_snowflake) + _guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=to_snowflake) + _channel_id: "Snowflake_Type" = attrs.field(repr=False, converter=to_snowflake) + _member_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=to_snowflake) @property def guild(self) -> "Guild": @@ -91,21 +94,21 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class VoiceRegion(DictSerializationMixin): """A voice region.""" - id: str = field(repr=True) + id: str = attrs.field(repr=True) """unique ID for the region""" - name: str = field(repr=True) + name: str = attrs.field(repr=True) """name of the region""" - vip: bool = field(default=False, repr=True) + vip: bool = attrs.field(default=False, repr=True) """whether this is a VIP-only voice region""" - optimal: bool = field(default=False) + optimal: bool = attrs.field(repr=False, default=False) """true for a single server that is closest to the current user's client""" - deprecated: bool = field(default=False) + deprecated: bool = attrs.field(repr=False, default=False) """whether this is a deprecated voice region (avoid switching to these)""" - custom: bool = field(default=False) + custom: bool = attrs.field(repr=False, default=False) """whether this is a custom voice region (used for events/etc)""" def __str__(self) -> str: diff --git a/naff/models/discord/webhooks.py b/naff/models/discord/webhooks.py index dbda973f7..0ab1fedff 100644 --- a/naff/models/discord/webhooks.py +++ b/naff/models/discord/webhooks.py @@ -1,11 +1,12 @@ +import re from enum import IntEnum from typing import Optional, TYPE_CHECKING, Union, Dict, Any, List -import re + +import attrs from naff.client.const import MISSING, Absent from naff.client.errors import ForeignWebhookException, EmptyMessageException from naff.client.mixins.send import SendMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.serializer import to_image_data from naff.models.discord.message import process_message_payload from naff.models.discord.snowflake import to_snowflake, to_optional_snowflake @@ -39,33 +40,35 @@ class WebhookTypes(IntEnum): """Application webhooks are webhooks used with Interactions""" -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Webhook(DiscordObject, SendMixin): - type: WebhookTypes = field() + type: WebhookTypes = attrs.field( + repr=False, + ) """The type of webhook""" - application_id: Optional["Snowflake_Type"] = field(default=None) + application_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the bot/OAuth2 application that created this webhook""" - guild_id: Optional["Snowflake_Type"] = field(default=None) + guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the guild id this webhook is for, if any""" - channel_id: Optional["Snowflake_Type"] = field(default=None) + channel_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the channel id this webhook is for, if any""" - user_id: Optional["Snowflake_Type"] = field(default=None) + user_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the user this webhook was created by""" - name: Optional[str] = field(default=None) + name: Optional[str] = attrs.field(repr=False, default=None) """the default name of the webhook""" - avatar: Optional[str] = field(default=None) + avatar: Optional[str] = attrs.field(repr=False, default=None) """the default user avatar hash of the webhook""" - token: str = field(default=MISSING) + token: str = attrs.field(repr=False, default=MISSING) """the secure token of the webhook (returned for Incoming Webhooks)""" - url: Optional[str] = field(default=None) + url: Optional[str] = attrs.field(repr=False, default=None) """the url used for executing the webhook (returned by the webhooks OAuth2 flow)""" - source_guild_id: Optional["Snowflake_Type"] = field(default=None) + source_guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the guild of the channel that this webhook is following (returned for Channel Follower Webhooks)""" - source_channel_id: Optional["Snowflake_Type"] = field(default=None) + source_channel_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None) """the channel that this webhook is following (returned for Channel Follower Webhooks)""" @classmethod @@ -137,6 +140,7 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] async def edit( self, + *, name: Absent[str] = MISSING, avatar: Absent["UPLOADABLE_TYPE"] = MISSING, channel_id: Absent["Snowflake_Type"] = MISSING, @@ -168,6 +172,7 @@ async def delete(self) -> None: async def send( self, content: Optional[str] = None, + *, embed: Optional[Union["Embed", dict]] = None, embeds: Optional[Union[List[Union["Embed", dict]], Union["Embed", dict]]] = None, components: Optional[ @@ -246,6 +251,7 @@ async def send( async def edit_message( self, message: Union["Message", "Snowflake_Type"], + *, content: Optional[str] = None, embeds: Optional[Union[List[Union["Embed", dict]], Union["Embed", dict]]] = None, components: Optional[ diff --git a/naff/models/naff/active_voice_state.py b/naff/models/naff/active_voice_state.py index d474ab65a..869650845 100644 --- a/naff/models/naff/active_voice_state.py +++ b/naff/models/naff/active_voice_state.py @@ -1,14 +1,14 @@ import asyncio from typing import Optional, TYPE_CHECKING +import attrs from discord_typings import VoiceStateData from naff.api.voice.player import Player from naff.api.voice.voice_gateway import VoiceGateway -from naff.client.const import logger, MISSING +from naff.client.const import MISSING from naff.client.errors import VoiceAlreadyConnected, VoiceConnectionTimeout from naff.client.utils import optional -from naff.client.utils.attr_utils import define, field from naff.models.discord.snowflake import Snowflake_Type, to_snowflake from naff.models.discord.voice_state import VoiceState @@ -19,18 +19,18 @@ __all__ = ("ActiveVoiceState",) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ActiveVoiceState(VoiceState): - ws: Optional[VoiceGateway] = field(default=None) + ws: Optional[VoiceGateway] = attrs.field(repr=False, default=None) """The websocket for this voice state""" - player: Optional[Player] = field(default=None) + player: Optional[Player] = attrs.field(repr=False, default=None) """The playback task that broadcasts audio data to discord""" - _volume: float = field(default=0.5) + _volume: float = attrs.field(repr=False, default=0.5) # standard voice states expect this data, this voice state lacks it initially; so we make them optional - user_id: "Snowflake_Type" = field(default=MISSING, converter=optional(to_snowflake)) - _guild_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) - _member_id: Optional["Snowflake_Type"] = field(default=None, converter=optional(to_snowflake)) + user_id: "Snowflake_Type" = attrs.field(repr=False, default=MISSING, converter=optional(to_snowflake)) + _guild_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) + _member_id: Optional["Snowflake_Type"] = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) def __attrs_post_init__(self) -> None: # jank line to handle the two inherently incompatible data structures @@ -141,7 +141,7 @@ async def connect(self, timeout: int = 5) -> None: raise VoiceAlreadyConnected await self.gateway.voice_state_update(self._guild_id, self._channel_id, self.self_mute, self.self_deaf) - logger.debug("Waiting for voice connection data...") + self.logger.debug("Waiting for voice connection data...") try: self._voice_state, self._voice_server = await asyncio.gather( @@ -151,7 +151,7 @@ async def connect(self, timeout: int = 5) -> None: except asyncio.TimeoutError: raise VoiceConnectionTimeout from None - logger.debug("Attempting to initialise voice gateway...") + self.logger.debug("Attempting to initialise voice gateway...") await self.ws_connect() async def disconnect(self) -> None: @@ -176,7 +176,7 @@ async def move(self, channel: "Snowflake_Type", timeout: int = 5) -> None: self._channel_id = target_channel await self.gateway.voice_state_update(self._guild_id, self._channel_id, self.self_mute, self.self_deaf) - logger.debug("Waiting for voice connection data...") + self.logger.debug("Waiting for voice connection data...") try: await self._client.wait_for("raw_voice_state_update", self._guild_predicate, timeout=timeout) except asyncio.TimeoutError: @@ -246,7 +246,7 @@ async def _voice_state_update( """ if after is None: # bot disconnected - logger.info(f"Disconnecting from voice channel {self._channel_id}") + self.logger.info(f"Disconnecting from voice channel {self._channel_id}") await self._close_connection() self._client.cache.delete_bot_voice_state(self._guild_id) return diff --git a/naff/models/naff/annotations/argument.py b/naff/models/naff/annotations/argument.py index f2304c183..ff96110e5 100644 --- a/naff/models/naff/annotations/argument.py +++ b/naff/models/naff/annotations/argument.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING from naff.client.errors import BadArgument -from naff.models.naff.converters import NoArgumentConverter from naff.models.naff.context import Context, PrefixedContext +from naff.models.naff.converters import NoArgumentConverter __all__ = ("CMD_ARGS", "CMD_AUTHOR", "CMD_BODY", "CMD_CHANNEL") diff --git a/naff/models/naff/annotations/slash.py b/naff/models/naff/annotations/slash.py index f73c77e4a..4a9ff5bd4 100644 --- a/naff/models/naff/annotations/slash.py +++ b/naff/models/naff/annotations/slash.py @@ -1,7 +1,6 @@ from typing import Union, List, Optional, Type, TYPE_CHECKING import naff.models as models - from naff.models.naff.application_commands import SlashCommandOption __all__ = ( diff --git a/naff/models/naff/application_commands.py b/naff/models/naff/application_commands.py index 1960b3c1d..733556232 100644 --- a/naff/models/naff/application_commands.py +++ b/naff/models/naff/application_commands.py @@ -15,12 +15,11 @@ SLASH_CMD_MAX_OPTIONS, SLASH_CMD_MAX_DESC_LENGTH, MISSING, - logger, Absent, ) from naff.client.mixins.serialization import DictSerializationMixin from naff.client.utils import optional -from naff.client.utils.attr_utils import define, field, docs, attrs_validator +from naff.client.utils.attr_utils import attrs_validator, docs from naff.client.utils.misc_utils import get_parameters from naff.client.utils.serializer import no_export_meta from naff.models.discord.enums import ChannelTypes, CommandTypes, Permissions @@ -34,6 +33,7 @@ if TYPE_CHECKING: from naff.models.discord.snowflake import Snowflake_Type from naff.models.naff.context import Context + from naff import Client __all__ = ( "OptionTypes", @@ -63,7 +63,7 @@ def name_validator(_: Any, attr: Attribute, value: str) -> None: if value: - if not re.match(rf"^[\w-]{{1,{SLASH_CMD_NAME_LENGTH}}}$", value) or value != value.lower(): + if not re.match(rf"^[\w-]{{1,{SLASH_CMD_NAME_LENGTH}}}$", value) or value != value.lower(): # noqa: W605 raise ValueError( f"Slash Command names must be lower case and match this regex: ^[\w-]{1, {SLASH_CMD_NAME_LENGTH} }$" # noqa: W605 ) @@ -75,7 +75,9 @@ def desc_validator(_: Any, attr: Attribute, value: str) -> None: raise ValueError(f"Description must be between 1 and {SLASH_CMD_MAX_DESC_LENGTH} characters long") -@define(field_transformer=attrs_validator(name_validator, skip_fields=["default_locale"])) +@attrs.define( + eq=False, order=False, hash=False, field_transformer=attrs_validator(name_validator, skip_fields=["default_locale"]) +) class LocalisedName(LocalisedField): """A localisation object for names.""" @@ -83,7 +85,9 @@ def __repr__(self) -> str: return super().__repr__() -@define(field_transformer=attrs_validator(desc_validator, skip_fields=["default_locale"])) +@attrs.define( + eq=False, order=False, hash=False, field_transformer=attrs_validator(desc_validator, skip_fields=["default_locale"]) +) class LocalisedDesc(LocalisedField): """A localisation object for descriptions.""" @@ -150,7 +154,7 @@ class CallbackTypes(IntEnum): MODAL = 9 -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class InteractionCommand(BaseCommand): """ Represents a discord abstract interaction command. @@ -164,30 +168,32 @@ class InteractionCommand(BaseCommand): """ - name: LocalisedName = field( - metadata=docs("1-32 character name") | no_export_meta, converter=LocalisedName.converter + name: LocalisedName | str = attrs.field( + repr=False, metadata=docs("1-32 character name") | no_export_meta, converter=LocalisedName.converter ) - scopes: List["Snowflake_Type"] = field( + scopes: List["Snowflake_Type"] = attrs.field( default=[GLOBAL_SCOPE], converter=to_snowflake_list, metadata=docs("The scopes of this interaction. Global or guild ids") | no_export_meta, ) - default_member_permissions: Optional["Permissions"] = field( - default=None, metadata=docs("What permissions members need to have by default to use this command") + default_member_permissions: Optional["Permissions"] = attrs.field( + repr=False, default=None, metadata=docs("What permissions members need to have by default to use this command") ) - dm_permission: bool = field(default=True, metadata=docs("Whether this command is enabled in DMs")) - cmd_id: Dict[str, "Snowflake_Type"] = field( - factory=dict, metadata=docs("The unique IDs of this commands") | no_export_meta + dm_permission: bool = attrs.field(repr=False, default=True, metadata=docs("Whether this command is enabled in DMs")) + cmd_id: Dict[str, "Snowflake_Type"] = attrs.field( + repr=False, factory=dict, metadata=docs("The unique IDs of this commands") | no_export_meta ) # scope: cmd_id - callback: Callable[..., Coroutine] = field( - default=None, metadata=docs("The coroutine to call when this interaction is received") | no_export_meta + callback: Callable[..., Coroutine] = attrs.field( + repr=False, + default=None, + metadata=docs("The coroutine to call when this interaction is received") | no_export_meta, ) - auto_defer: "AutoDefer" = field( + auto_defer: "AutoDefer" = attrs.field( default=MISSING, metadata=docs("A system to automatically defer this command after a set duration") | no_export_meta, ) - nsfw: bool = field(default=False, metadata=docs("This command should only work in NSFW channels")) - _application_id: "Snowflake_Type" = field(default=None, converter=optional(to_snowflake)) + nsfw: bool = attrs.field(repr=False, default=False, metadata=docs("This command should only work in NSFW channels")) + _application_id: "Snowflake_Type" = attrs.field(repr=False, default=None, converter=optional(to_snowflake)) def __attrs_post_init__(self) -> None: if self.callback is not None: @@ -245,8 +251,35 @@ async def _permission_enforcer(self, ctx: "Context") -> bool: return ctx.guild is not None return True + def is_enabled(self, ctx: "Context") -> bool: + """ + Check if this command is enabled in the given context. + + Args: + ctx: The context to check. + + Returns: + Whether this command is enabled in the given context. + """ + if not self.dm_permission and ctx.guild is None: + return False + elif self.dm_permission and ctx.guild is None: + # remaining checks are impossible if this is a DM and DMs are enabled + return True + + if self.nsfw and not ctx.channel.is_nsfw(): + return False + if cmd_perms := ctx.guild.command_permissions.get(self.get_cmd_id(ctx.guild.id)): + if not cmd_perms.is_enabled_in_context(ctx): + return False + if self.default_member_permissions is not None: + channel_perms = ctx.author.channel_permissions(ctx.channel) + if any(perm not in channel_perms for perm in self.default_member_permissions): + return False + return True + -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ContextMenu(InteractionCommand): """ Represents a discord context menu. @@ -257,8 +290,10 @@ class ContextMenu(InteractionCommand): """ - name: LocalisedField = field(metadata=docs("1-32 character name"), converter=LocalisedField.converter) - type: CommandTypes = field(metadata=docs("The type of command, defaults to 1 if not specified")) + name: LocalisedField = attrs.field( + repr=False, metadata=docs("1-32 character name"), converter=LocalisedField.converter + ) + type: CommandTypes = attrs.field(repr=False, metadata=docs("The type of command, defaults to 1 if not specified")) @type.validator def _type_validator(self, attribute: str, value: int) -> None: @@ -277,7 +312,7 @@ def to_dict(self) -> dict: return data -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class SlashCommandChoice(DictSerializationMixin): """ Represents a discord slash command choice. @@ -288,14 +323,16 @@ class SlashCommandChoice(DictSerializationMixin): """ - name: LocalisedField = field(converter=LocalisedField.converter) - value: Union[str, int, float] = field() + name: LocalisedField | str = attrs.field(repr=False, converter=LocalisedField.converter) + value: Union[str, int, float] = attrs.field( + repr=False, + ) def as_dict(self) -> dict: return {"name": str(self.name), "value": self.value, "name_localizations": self.name.to_locale_dict()} -@define(kw_only=False) +@attrs.define(eq=False, order=False, hash=False, kw_only=False) class SlashCommandOption(DictSerializationMixin): """ Represents a discord slash command option. @@ -314,17 +351,21 @@ class SlashCommandOption(DictSerializationMixin): """ - name: LocalisedName = field(converter=LocalisedName.converter) - type: Union[OptionTypes, int] = field() - description: LocalisedDesc = field(default="No Description Set", converter=LocalisedDesc.converter) - required: bool = field(default=True) - autocomplete: bool = field(default=False) - choices: List[Union[SlashCommandChoice, Dict]] = field(factory=list) - channel_types: Optional[list[Union[ChannelTypes, int]]] = field(default=None) - min_value: Optional[float] = field(default=None) - max_value: Optional[float] = field(default=None) - min_length: Optional[int] = field(default=None) - max_length: Optional[int] = field(default=None) + name: LocalisedName | str = attrs.field(repr=False, converter=LocalisedName.converter) + type: Union[OptionTypes, int] = attrs.field( + repr=False, + ) + description: LocalisedDesc | str = attrs.field( + repr=False, default="No Description Set", converter=LocalisedDesc.converter + ) + required: bool = attrs.field(repr=False, default=True) + autocomplete: bool = attrs.field(repr=False, default=False) + choices: List[Union[SlashCommandChoice, Dict]] = attrs.field(repr=False, factory=list) + channel_types: Optional[list[Union[ChannelTypes, int]]] = attrs.field(repr=False, default=None) + min_value: Optional[float] = attrs.field(repr=False, default=None) + max_value: Optional[float] = attrs.field(repr=False, default=None) + min_length: Optional[int] = attrs.field(repr=False, default=None) + max_length: Optional[int] = attrs.field(repr=False, default=None) @type.validator def _type_validator(self, attribute: str, value: int) -> None: @@ -412,23 +453,29 @@ def as_dict(self) -> dict: return data -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class SlashCommand(InteractionCommand): - name: LocalisedName = field(converter=LocalisedName.converter) - description: LocalisedDesc = field(default="No Description Set", converter=LocalisedDesc.converter) + name: LocalisedName | str = attrs.field(repr=False, converter=LocalisedName.converter) + description: LocalisedDesc = attrs.field( + repr=False, default="No Description Set", converter=LocalisedDesc.converter + ) - group_name: LocalisedName = field(default=None, metadata=no_export_meta, converter=LocalisedName.converter) - group_description: LocalisedDesc = field( - default="No Description Set", metadata=no_export_meta, converter=LocalisedDesc.converter + group_name: LocalisedName | str = attrs.field( + repr=False, default=None, metadata=no_export_meta, converter=LocalisedName.converter + ) + group_description: LocalisedDesc = attrs.field( + repr=False, default="No Description Set", metadata=no_export_meta, converter=LocalisedDesc.converter ) - sub_cmd_name: LocalisedName = field(default=None, metadata=no_export_meta, converter=LocalisedName.converter) - sub_cmd_description: LocalisedDesc = field( - default="No Description Set", metadata=no_export_meta, converter=LocalisedDesc.converter + sub_cmd_name: LocalisedName | str = attrs.field( + repr=False, default=None, metadata=no_export_meta, converter=LocalisedName.converter + ) + sub_cmd_description: LocalisedDesc = attrs.field( + repr=False, default="No Description Set", metadata=no_export_meta, converter=LocalisedDesc.converter ) - options: List[Union[SlashCommandOption, Dict]] = field(factory=list) - autocomplete_callbacks: dict = field(factory=dict, metadata=no_export_meta) + options: List[Union[SlashCommandOption, Dict]] = attrs.field(repr=False, factory=list) + autocomplete_callbacks: dict = attrs.field(repr=False, factory=dict, metadata=no_export_meta) @property def resolved_name(self) -> str: @@ -572,14 +619,16 @@ def wrapper(call: Callable[..., Coroutine]) -> "SlashCommand": return wrapper -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ComponentCommand(InteractionCommand): # right now this adds no extra functionality, but for future dev ive implemented it - name: str = field() - listeners: list[str] = field(factory=list) + name: str = attrs.field( + repr=False, + ) + listeners: list[str] = attrs.field(repr=False, factory=list) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ModalCommand(ComponentCommand): ... @@ -935,7 +984,9 @@ def wrapper(func: Coroutine) -> Coroutine: return wrapper -def application_commands_to_dict(commands: Dict["Snowflake_Type", Dict[str, InteractionCommand]]) -> dict: +def application_commands_to_dict( + commands: Dict["Snowflake_Type", Dict[str, InteractionCommand]], client: "Client" +) -> dict: """ Convert the command list into a format that would be accepted by discord. @@ -1008,7 +1059,7 @@ def squash_subcommand(subcommands: List) -> Dict: nsfw = cmd_list[0].nsfw if not all(str(c.description) in (str(base_description), "No Description Set") for c in cmd_list): - logger.warning( + client.logger.warning( f"Conflicting descriptions found in `{cmd_list[0].name}` subcommands; `{str(base_description)}` will be used" ) if not all(c.default_member_permissions == cmd_list[0].default_member_permissions for c in cmd_list): @@ -1016,7 +1067,7 @@ def squash_subcommand(subcommands: List) -> Dict: if not all(c.dm_permission == cmd_list[0].dm_permission for c in cmd_list): raise ValueError(f"Conflicting `dm_permission` values found in `{cmd_list[0].name}`") if not all(c.nsfw == nsfw for c in cmd_list): - logger.warning(f"Conflicting `nsfw` values found in {cmd_list[0].name} - `True` will be used") + client.logger.warning(f"Conflicting `nsfw` values found in {cmd_list[0].name} - `True` will be used") nsfw = True for cmd in cmd_list: @@ -1054,6 +1105,9 @@ def _compare_commands(local_cmd: dict, remote_cmd: dict) -> bool: "name_localized": ("name_localizations", None), "description_localized": ("description_localizations", None), } + if remote_cmd.get("guild_id"): + # non-global command + del lookup["dm_permission"] for local_name, comparison_data in lookup.items(): remote_name, default_value = comparison_data @@ -1074,7 +1128,7 @@ def _compare_options(local_opt_list: dict, remote_opt_list: dict) -> bool: "max_value": ("max_value", None), "min_value": ("min_value", None), "max_length": ("max_length", None), - "min_length": ("max_length", None), + "min_length": ("min_length", None), } post_process: Dict[str, Callable] = { "choices": lambda l: [d | {"name_localizations": {}} if len(d) == 2 else d for d in l], diff --git a/naff/models/naff/auto_defer.py b/naff/models/naff/auto_defer.py index ec6c738d3..baf17245b 100644 --- a/naff/models/naff/auto_defer.py +++ b/naff/models/naff/auto_defer.py @@ -1,8 +1,9 @@ import asyncio from typing import TYPE_CHECKING +import attrs + from naff.client.errors import AlreadyDeferred, NotFound, BadRequest, HTTPException -from naff.client.utils.attr_utils import define, field if TYPE_CHECKING: from naff.models.naff.context import InteractionContext @@ -10,17 +11,17 @@ __all__ = ("AutoDefer",) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AutoDefer: """Automatically defer application commands for you!""" - enabled: bool = field(default=False) + enabled: bool = attrs.field(repr=False, default=False) """Whether or not auto-defer is enabled""" - ephemeral: bool = field(default=False) + ephemeral: bool = attrs.field(repr=False, default=False) """Should the command be deferred as ephemeral or not""" - time_until_defer: float = field(default=1.5) + time_until_defer: float = attrs.field(repr=False, default=1.5) """How long to wait before automatically deferring""" async def __call__(self, ctx: "InteractionContext") -> None: diff --git a/naff/models/naff/checks.py b/naff/models/naff/checks.py index 07cf38294..72bd16cf7 100644 --- a/naff/models/naff/checks.py +++ b/naff/models/naff/checks.py @@ -1,11 +1,9 @@ from typing import Awaitable, Callable -from naff.models.discord.snowflake import Snowflake_Type, to_snowflake -from naff.models.naff.context import Context - from naff.models.discord.role import Role +from naff.models.discord.snowflake import Snowflake_Type, to_snowflake from naff.models.discord.user import Member - +from naff.models.naff.context import Context __all__ = ("has_role", "has_any_role", "has_id", "is_owner", "guild_only", "dm_only") diff --git a/naff/models/naff/command.py b/naff/models/naff/command.py index 99ba87478..c359e1ebe 100644 --- a/naff/models/naff/command.py +++ b/naff/models/naff/command.py @@ -1,4 +1,5 @@ from __future__ import annotations + import asyncio import copy import functools @@ -6,13 +7,15 @@ import typing from typing import Annotated, Awaitable, Callable, Coroutine, Optional, Tuple, Any, TYPE_CHECKING -from naff.models.naff.callback import CallbackObject +import attrs + from naff.client.const import MISSING from naff.client.errors import CommandOnCooldown, CommandCheckFailure, MaxConcurrencyReached from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field, docs +from naff.client.utils.attr_utils import docs from naff.client.utils.misc_utils import get_parameters, get_object_name, maybe_coroutine from naff.client.utils.serializer import no_export_meta +from naff.models.naff.callback import CallbackObject from naff.models.naff.cooldowns import Cooldown, Buckets, MaxConcurrency from naff.models.naff.protocols import Converter @@ -26,7 +29,7 @@ args_reg = re.compile(r"^\*\w") -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class BaseCommand(DictSerializationMixin, CallbackObject): """ An object all commands inherit from. Outlines the basic structure of a command, and handles checks. @@ -42,33 +45,41 @@ class BaseCommand(DictSerializationMixin, CallbackObject): """ - extension: Any = field(default=None, metadata=docs("The extension this command belongs to") | no_export_meta) + extension: Any = attrs.field( + repr=False, default=None, metadata=docs("The extension this command belongs to") | no_export_meta + ) - enabled: bool = field(default=True, metadata=docs("Whether this can be run at all") | no_export_meta) - checks: list = field( - factory=list, metadata=docs("Any checks that must be *checked* before the command can run") | no_export_meta + enabled: bool = attrs.field( + repr=False, default=True, metadata=docs("Whether this can be run at all") | no_export_meta + ) + checks: list = attrs.field( + repr=False, + factory=list, + metadata=docs("Any checks that must be *checked* before the command can run") | no_export_meta, ) - cooldown: Cooldown = field( - default=MISSING, metadata=docs("An optional cooldown to apply to the command") | no_export_meta + cooldown: Cooldown = attrs.field( + repr=False, default=MISSING, metadata=docs("An optional cooldown to apply to the command") | no_export_meta ) - max_concurrency: MaxConcurrency = field( + max_concurrency: MaxConcurrency = attrs.field( default=MISSING, metadata=docs("An optional maximum number of concurrent instances to apply to the command") | no_export_meta, ) - callback: Callable[..., Coroutine] = field( - default=None, metadata=docs("The coroutine to be called for this command") | no_export_meta + callback: Callable[..., Coroutine] = attrs.field( + repr=False, default=None, metadata=docs("The coroutine to be called for this command") | no_export_meta ) - error_callback: Callable[..., Coroutine] = field( - default=None, metadata=no_export_meta | docs("The coroutine to be called when an error occurs") + error_callback: Callable[..., Coroutine] = attrs.field( + repr=False, default=None, metadata=no_export_meta | docs("The coroutine to be called when an error occurs") ) - pre_run_callback: Callable[..., Coroutine] = field( + pre_run_callback: Callable[..., Coroutine] = attrs.field( default=None, metadata=no_export_meta | docs("The coroutine to be called before the command is executed, **but** after the checks"), ) - post_run_callback: Callable[..., Coroutine] = field( - default=None, metadata=no_export_meta | docs("The coroutine to be called after the command has executed") + post_run_callback: Callable[..., Coroutine] = attrs.field( + repr=False, + default=None, + metadata=no_export_meta | docs("The coroutine to be called after the command has executed"), ) def __attrs_post_init__(self) -> None: diff --git a/naff/models/naff/context.py b/naff/models/naff/context.py index 52b06b918..eed3469a7 100644 --- a/naff/models/naff/context.py +++ b/naff/models/naff/context.py @@ -1,19 +1,22 @@ import datetime -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Union, runtime_checkable +from logging import Logger +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Union, runtime_checkable, Sequence +import attrs from aiohttp import FormData +import naff.models as models import naff.models.discord.message as message -from naff.models.discord.timestamp import Timestamp -from naff.client.const import MISSING, logger, Absent +from naff.client.const import Absent, MISSING, get_logger from naff.client.errors import AlreadyDeferred from naff.client.mixins.send import SendMixin -from naff.client.utils.attr_utils import define, field, docs from naff.client.utils.attr_converters import optional +from naff.client.utils.attr_utils import docs from naff.models.discord.enums import MessageFlags, CommandTypes, Permissions from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.discord.message import Attachment from naff.models.discord.snowflake import to_snowflake, to_optional_snowflake +from naff.models.discord.timestamp import Timestamp from naff.models.naff.application_commands import CallbackTypes, OptionTypes if TYPE_CHECKING: @@ -49,27 +52,27 @@ ) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Resolved: """Represents resolved data in an interaction.""" - channels: Dict["Snowflake_Type", "TYPE_MESSAGEABLE_CHANNEL"] = field( - factory=dict, metadata=docs("A dictionary of channels mentioned in the interaction") + channels: Dict["Snowflake_Type", "TYPE_MESSAGEABLE_CHANNEL"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of channels mentioned in the interaction") ) - members: Dict["Snowflake_Type", "Member"] = field( - factory=dict, metadata=docs("A dictionary of members mentioned in the interaction") + members: Dict["Snowflake_Type", "Member"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of members mentioned in the interaction") ) - users: Dict["Snowflake_Type", "User"] = field( - factory=dict, metadata=docs("A dictionary of users mentioned in the interaction") + users: Dict["Snowflake_Type", "User"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of users mentioned in the interaction") ) - roles: Dict["Snowflake_Type", "Role"] = field( - factory=dict, metadata=docs("A dictionary of roles mentioned in the interaction") + roles: Dict["Snowflake_Type", "Role"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of roles mentioned in the interaction") ) - messages: Dict["Snowflake_Type", "Message"] = field( - factory=dict, metadata=docs("A dictionary of messages mentioned in the interaction") + messages: Dict["Snowflake_Type", "Message"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of messages mentioned in the interaction") ) - attachments: Dict["Snowflake_Type", "Attachment"] = field( - factory=dict, metadata=docs("A dictionary of attachments tied to the interaction") + attachments: Dict["Snowflake_Type", "Attachment"] = attrs.field( + repr=False, factory=dict, metadata=docs("A dictionary of attachments tied to the interaction") ) @classmethod @@ -105,23 +108,34 @@ def from_dict(cls, client: "Client", data: dict, guild_id: Optional["Snowflake_T return new_cls -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class Context: """Represents the context of a command.""" - _client: "Client" = field(default=None) - invoke_target: str = field(default=None, metadata=docs("The name of the command to be invoked")) - command: Optional["BaseCommand"] = field(default=None, metadata=docs("The command to be invoked")) + _client: "Client" = attrs.field(repr=False, default=None) + invoke_target: str = attrs.field(repr=False, default=None, metadata=docs("The name of the command to be invoked")) + command: Optional["BaseCommand"] = attrs.field(repr=False, default=None, metadata=docs("The command to be invoked")) - args: List = field(factory=list, metadata=docs("The list of arguments to be passed to the command")) - kwargs: Dict = field(factory=dict, metadata=docs("The list of keyword arguments to be passed")) + args: List = attrs.field( + repr=False, factory=list, metadata=docs("The list of arguments to be passed to the command") + ) + kwargs: Dict = attrs.field(repr=False, factory=dict, metadata=docs("The list of keyword arguments to be passed")) - author: Union["Member", "User"] = field(default=None, metadata=docs("The author of the message")) - channel: "TYPE_MESSAGEABLE_CHANNEL" = field(default=None, metadata=docs("The channel this was sent within")) - guild_id: "Snowflake_Type" = field( - default=None, converter=to_optional_snowflake, metadata=docs("The guild this was sent within, if not a DM") + author: Union["Member", "User"] = attrs.field(repr=False, default=None, metadata=docs("The author of the message")) + channel: "TYPE_MESSAGEABLE_CHANNEL" = attrs.field( + repr=False, default=None, metadata=docs("The channel this was sent within") + ) + guild_id: "Snowflake_Type" = attrs.field( + repr=False, + default=None, + converter=to_optional_snowflake, + metadata=docs("The guild this was sent within, if not a DM"), + ) + message: "Message" = attrs.field( + repr=False, default=None, metadata=docs("The message associated with this context") ) - message: "Message" = field(default=None, metadata=docs("The message associated with this context")) + + logger: Logger = attrs.field(repr=False, init=False, factory=get_logger) @property def guild(self) -> Optional["Guild"]: @@ -137,36 +151,44 @@ def voice_state(self) -> Optional["ActiveVoiceState"]: return self._client.cache.get_bot_voice_state(self.guild_id) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class _BaseInteractionContext(Context): """An internal object used to define the attributes of interaction context and its children.""" - _token: str = field(default=None, metadata=docs("The token for the interaction")) - _context_type: int = field() # we don't want to convert this in case of a new context type, which is expected - interaction_id: str = field(default=None, metadata=docs("The id of the interaction")) - target_id: "Snowflake_Type" = field( + _token: str = attrs.field(repr=False, default=None, metadata=docs("The token for the interaction")) + _context_type: int = attrs.field( + repr=False, + ) # we don't want to convert this in case of a new context type, which is expected + interaction_id: "Snowflake_Type" = attrs.field( + repr=False, default=None, metadata=docs("The id of the interaction"), converter=to_snowflake + ) + target_id: "Snowflake_Type" = attrs.field( default=None, metadata=docs("The ID of the target, used for context menus to show what was clicked on"), converter=optional(to_snowflake), ) - app_permissions: Permissions = field( - default=0, converter=Permissions, metadata=docs("The permissions this interaction has") + app_permissions: Permissions = attrs.field( + repr=False, default=0, converter=Permissions, metadata=docs("The permissions this interaction has") ) - locale: str = field( + locale: str = attrs.field( default=None, metadata=docs( "The selected language of the invoking user \n(https://discord.com/developers/docs/reference#locales)" ), ) - guild_locale: str = field(default=None, metadata=docs("The guild's preferred locale")) + guild_locale: str = attrs.field(repr=False, default=None, metadata=docs("The guild's preferred locale")) - deferred: bool = field(default=False, metadata=docs("Is this interaction deferred?")) - responded: bool = field(default=False, metadata=docs("Have we responded to the interaction?")) - ephemeral: bool = field(default=False, metadata=docs("Are responses to this interaction *hidden*")) + deferred: bool = attrs.field(repr=False, default=False, metadata=docs("Is this interaction deferred?")) + responded: bool = attrs.field(repr=False, default=False, metadata=docs("Have we responded to the interaction?")) + ephemeral: bool = attrs.field( + repr=False, default=False, metadata=docs("Are responses to this interaction *hidden*") + ) - resolved: Resolved = field(default=Resolved(), metadata=docs("Discord objects mentioned within this interaction")) + resolved: Resolved = attrs.field( + repr=False, default=Resolved(), metadata=docs("Discord objects mentioned within this interaction") + ) - data: Dict = field(factory=dict, metadata=docs("The raw data of this interaction")) + data: Dict = attrs.field(repr=False, factory=dict, metadata=docs("The raw data of this interaction")) @classmethod def from_dict(cls, data: Dict, client: "Client") -> "Context": @@ -286,7 +308,7 @@ async def send_modal(self, modal: Union[dict, "Modal"]) -> Union[dict, "Modal"]: return modal -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class InteractionContext(_BaseInteractionContext, SendMixin): """ Represents the context of an interaction. @@ -345,6 +367,7 @@ async def _send_http_request( async def send( self, content: Optional[str] = None, + *, embeds: Optional[Union[Iterable[Union["Embed", dict]], Union["Embed", dict]]] = None, embed: Optional[Union["Embed", dict]] = None, components: Optional[ @@ -410,6 +433,62 @@ async def send( flags=flags, ) + async def delete(self, message: "Snowflake_Type") -> None: + """ + Delete a message sent in response to this interaction. + + Args: + message: The message to delete + + """ + await self._client.http.delete_interaction_message(self._client.app.id, self._token, to_snowflake(message)) + + async def edit( + self, + message: "Snowflake_Type", + *, + content: Optional[str] = None, + embeds: Optional[Union[Sequence[Union["models.Embed", dict]], Union["models.Embed", dict]]] = None, + embed: Optional[Union["models.Embed", dict]] = None, + components: Optional[ + Union[ + Sequence[Sequence[Union["models.BaseComponent", dict]]], + Sequence[Union["models.BaseComponent", dict]], + "models.BaseComponent", + dict, + ] + ] = None, + allowed_mentions: Optional[Union["models.AllowedMentions", dict]] = None, + attachments: Optional[Optional[Sequence[Union[Attachment, dict]]]] = None, + files: Optional[Union[UPLOADABLE_TYPE, Sequence[UPLOADABLE_TYPE]]] = None, + file: Optional[UPLOADABLE_TYPE] = None, + tts: bool = False, + ) -> "models.Message": + message_payload = models.process_message_payload( + content=content, + embeds=embeds or embed, + components=components, + allowed_mentions=allowed_mentions, + attachments=attachments, + tts=tts, + ) + + if file: + if files: + files = [file, *files] + else: + files = [file] + + message_data = await self._client.http.edit_interaction_message( + payload=message_payload, + application_id=self._client.app.id, + token=self._token, + message_id=to_snowflake(message), + files=files, + ) + if message_data: + return self._client.cache.place_message_data(message_data) + @property def target(self) -> "Absent[Member | User | Message]": """For context menus, this will be the object of which was clicked on.""" @@ -422,21 +501,21 @@ def target(self) -> "Absent[Member | User | Message]": # This can only be in the member or user cache caches = ( (self._client.cache.get_member, (self.guild_id, self.target_id)), - (self._client.cache.get_user, self.target_id), + (self._client.cache.get_user, (self.target_id,)), ) case CommandTypes.MESSAGE: # This can only be in the message cache caches = ((self._client.cache.get_message, (self.channel.id, self.target_id)),) case _: # Most likely a new context type, check all rational caches for the target_id - logger.warning(f"New Context Type Detected. Please Report: {self._context_type}") + self.logger.warning(f"New Context Type Detected. Please Report: {self._context_type}") caches = ( (self._client.cache.get_message, (self.channel.id, self.target_id)), (self._client.cache.get_member, (self.guild_id, self.target_id)), - (self._client.cache.get_user, self.target_id), - (self._client.cache.get_channel, self.target_id), - (self._client.cache.get_role, self.target_id), - (self._client.cache.get_emoji, self.target_id), # unlikely, so check last + (self._client.cache.get_user, (self.target_id,)), + (self._client.cache.get_channel, (self.target_id,)), + (self._client.cache.get_role, (self.target_id,)), + (self._client.cache.get_emoji, (self.target_id,)), # unlikely, so check last ) for cache, keys in caches: @@ -446,14 +525,20 @@ def target(self) -> "Absent[Member | User | Message]": return thing -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ComponentContext(InteractionContext): - custom_id: str = field(default="", metadata=docs("The ID given to the component that has been pressed")) - component_type: int = field(default=0, metadata=docs("The type of component that has been pressed")) + custom_id: str = attrs.field( + repr=False, default="", metadata=docs("The ID given to the component that has been pressed") + ) + component_type: int = attrs.field( + repr=False, default=0, metadata=docs("The type of component that has been pressed") + ) - values: List = field(factory=list, metadata=docs("The values set")) + values: List = attrs.field(repr=False, factory=list, metadata=docs("The values set")) - defer_edit_origin: bool = field(default=False, metadata=docs("Are we editing the message the component is on")) + defer_edit_origin: bool = attrs.field( + repr=False, default=False, metadata=docs("Are we editing the message the component is on") + ) @classmethod def from_dict(cls, data: Dict, client: "Client") -> "ComponentContext": @@ -461,6 +546,7 @@ def from_dict(cls, data: Dict, client: "Client") -> "ComponentContext": new_cls = super().from_dict(data, client) new_cls.token = data["token"] new_cls.interaction_id = data["id"] + new_cls.invoke_target = data["data"]["custom_id"] new_cls.custom_id = data["data"]["custom_id"] new_cls.component_type = data["data"]["component_type"] new_cls.message = client.cache.place_message_data(data["message"]) @@ -498,6 +584,7 @@ async def defer(self, ephemeral: bool = False, edit_origin: bool = False) -> Non async def edit_origin( self, + *, content: str = None, embeds: Optional[Union[Iterable[Union["Embed", dict]], Union["Embed", dict]]] = None, embed: Optional[Union["Embed", dict]] = None, @@ -546,7 +633,7 @@ async def edit_origin( message_data = None if self.deferred: if not self.defer_edit_origin: - logger.warning( + self.logger.warning( "If you want to edit the original message, and need to defer, you must set the `edit_origin` kwarg to True!" ) @@ -567,9 +654,11 @@ async def edit_origin( return self.message -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class AutocompleteContext(_BaseInteractionContext): - focussed_option: str = field(default=MISSING, metadata=docs("The option the user is currently filling in")) + focussed_option: str = attrs.field( + repr=False, default=MISSING, metadata=docs("The option the user is currently filling in") + ) @classmethod def from_dict(cls, data: Dict, client: "Client") -> "ComponentContext": @@ -614,9 +703,9 @@ async def send(self, choices: Iterable[Union[str, int, float, Dict[str, Union[st await self._client.http.post_initial_response(payload, self.interaction_id, self._token) -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class ModalContext(InteractionContext): - custom_id: str = field(default="") + custom_id: str = attrs.field(repr=False, default="") @classmethod def from_dict(cls, data: Dict, client: "Client") -> "ModalContext": @@ -639,9 +728,9 @@ def responses(self) -> dict[str, str]: return self.kwargs -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class PrefixedContext(Context, SendMixin): - prefix: str = field(default=MISSING, metadata=docs("The prefix used to invoke this command")) + prefix: str = attrs.field(repr=False, default=MISSING, metadata=docs("The prefix used to invoke this command")) @classmethod def from_message(cls, client: "Client", message: "Message") -> "PrefixedContext": @@ -674,7 +763,7 @@ async def _send_http_request( return await self._client.http.create_message(message_payload, self.channel.id, files=files) -@define +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class HybridContext(Context): """ Represents the context for hybrid commands, a slash command that can also be used as a prefixed command. @@ -682,14 +771,14 @@ class HybridContext(Context): This attempts to create a compatibility layer to allow contexts for an interaction or a message to be used seamlessly. """ - deferred: bool = field(default=False, metadata=docs("Is this context deferred?")) - responded: bool = field(default=False, metadata=docs("Have we responded to this?")) - app_permissions: Permissions = field( - default=0, converter=Permissions, metadata=docs("The permissions this context has") + deferred: bool = attrs.field(repr=False, default=False, metadata=docs("Is this context deferred?")) + responded: bool = attrs.field(repr=False, default=False, metadata=docs("Have we responded to this?")) + app_permissions: Permissions = attrs.field( + repr=False, default=0, converter=Permissions, metadata=docs("The permissions this context has") ) - _interaction_context: Optional[InteractionContext] = field(default=None) - _prefixed_context: Optional[PrefixedContext] = field(default=None) + _interaction_context: Optional[InteractionContext] = attrs.field(repr=False, default=None) + _prefixed_context: Optional[PrefixedContext] = attrs.field(repr=False, default=None) @classmethod def from_interaction_context(cls, context: InteractionContext) -> "HybridContext": @@ -815,6 +904,7 @@ async def reply( async def send( self, content: Optional[str] = None, + *, embeds: Optional[Union[Iterable[Union["Embed", dict]], Union["Embed", dict]]] = None, embed: Optional[Union["Embed", dict]] = None, components: Optional[ @@ -897,7 +987,9 @@ def guild(self) -> Optional["Guild"]: async def send( self, content: Optional[str] = None, + *, embeds: Optional[Union[Iterable[Union["Embed", dict]], Union["Embed", dict]]] = None, + embed: Optional[Union["Embed", dict]] = None, components: Optional[ Union[ Iterable[Iterable[Union["BaseComponent", dict]]], @@ -909,9 +1001,12 @@ async def send( stickers: Optional[Union[Iterable[Union["Sticker", "Snowflake_Type"]], "Sticker", "Snowflake_Type"]] = None, allowed_mentions: Optional[Union["AllowedMentions", dict]] = None, reply_to: Optional[Union["MessageReference", "Message", dict, "Snowflake_Type"]] = None, - file: Optional[Union["File", "IOBase", "Path", str]] = None, + files: Optional[Union["UPLOADABLE_TYPE", Iterable["UPLOADABLE_TYPE"]]] = None, + file: Optional["UPLOADABLE_TYPE"] = None, tts: bool = False, + suppress_embeds: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, + delete_after: Optional[float] = None, **kwargs: Any, ) -> "Message": ... diff --git a/naff/models/naff/converters.py b/naff/models/naff/converters.py index 02725d199..83d77ff3c 100644 --- a/naff/models/naff/converters.py +++ b/naff/models/naff/converters.py @@ -3,13 +3,8 @@ from typing import Any, Optional, List from naff.client.const import T, T_co +from naff.client.errors import BadArgument from naff.client.errors import Forbidden, HTTPException -from naff.models.discord.role import Role -from naff.models.discord.guild import Guild -from naff.models.discord.message import Message -from naff.models.discord.user import User, Member -from naff.models.discord.snowflake import SnowflakeObject -from naff.models.discord.emoji import PartialEmoji, CustomEmoji from naff.models.discord.channel import ( BaseChannel, DMChannel, @@ -33,10 +28,14 @@ TYPE_VOICE_CHANNEL, TYPE_MESSAGEABLE_CHANNEL, ) -from naff.models.naff.protocols import Converter +from naff.models.discord.emoji import PartialEmoji, CustomEmoji +from naff.models.discord.guild import Guild +from naff.models.discord.message import Message +from naff.models.discord.role import Role +from naff.models.discord.snowflake import SnowflakeObject +from naff.models.discord.user import User, Member from naff.models.naff.context import Context -from naff.client.errors import BadArgument - +from naff.models.naff.protocols import Converter __all__ = ( "NoArgumentConverter", diff --git a/naff/models/naff/extension.py b/naff/models/naff/extension.py index 75041b01a..1265ccece 100644 --- a/naff/models/naff/extension.py +++ b/naff/models/naff/extension.py @@ -1,15 +1,17 @@ import asyncio import inspect -from typing import Awaitable, List, TYPE_CHECKING, Callable, Coroutine, Optional +from typing import Awaitable, Dict, List, TYPE_CHECKING, Callable, Coroutine, Optional import naff.models.naff as naff -from naff.client.const import logger, MISSING +from naff.client.const import MISSING from naff.client.utils.misc_utils import wrap_partial +from naff.models.naff import ContextMenu from naff.models.naff.tasks import Task if TYPE_CHECKING: from naff.client import Client - from naff.models.naff import AutoDefer, BaseCommand, Listener + from naff.models.discord import Snowflake_Type + from naff.models.naff import AutoDefer, BaseCommand, InteractionCommand, Listener from naff.models.naff import Context @@ -38,6 +40,7 @@ async def some_command(self, context): extension_checks str: A list of checks to be ran on any command in this extension extension_prerun List: A list of coroutines to be run before any command in this extension extension_postrun List: A list of coroutines to be run after any command in this extension + interaction_tree Dict: A dictionary of registered application commands in a tree """ @@ -49,6 +52,7 @@ async def some_command(self, context): extension_prerun: List extension_postrun: List extension_error: Optional[Callable[..., Coroutine]] + interaction_tree: Dict["Snowflake_Type", Dict[str, "InteractionCommand" | Dict[str, "InteractionCommand"]]] _commands: List _listeners: List auto_defer: "AutoDefer" @@ -61,6 +65,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": new_cls.extension_prerun = [] new_cls.extension_postrun = [] new_cls.extension_error = None + new_cls.interaction_tree = {} new_cls.auto_defer = MISSING new_cls.description = kwargs.get("Description", None) @@ -89,7 +94,27 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": elif isinstance(val, naff.HybridCommand): bot.add_hybrid_command(val) elif isinstance(val, naff.InteractionCommand): - bot.add_interaction(val) + if not bot.add_interaction(val): + continue + base, group, sub, *_ = val.resolved_name.split(" ") + [None, None] + for scope in val.scopes: + if scope not in new_cls.interaction_tree: + new_cls.interaction_tree[scope] = {} + if group is None or isinstance(val, ContextMenu): + new_cls.interaction_tree[scope][val.resolved_name] = val + elif group is not None: + if not (current := new_cls.interaction_tree[scope].get(base)) or isinstance( + current, naff.InteractionCommand + ): + new_cls.interaction_tree[scope][base] = {} + if sub is None: + new_cls.interaction_tree[scope][base][group] = val + else: + if not (current := new_cls.interaction_tree[scope][base].get(group)) or isinstance( + current, naff.InteractionCommand + ): + new_cls.interaction_tree[scope][base][group] = {} + new_cls.interaction_tree[scope][base][group][sub] = val else: bot.add_prefixed_command(val) @@ -100,7 +125,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": elif isinstance(val, Task): wrap_partial(val, new_cls) - logger.debug( + bot.logger.debug( f"{len(new_cls._commands)} commands and {len(new_cls.listeners)} listeners" f" have been loaded from `{new_cls.name}`" ) @@ -197,7 +222,7 @@ def drop(self) -> None: self.bot.listeners[func.event].remove(func) self.bot.ext.pop(self.name, None) - logger.debug(f"{self.name} has been drop") + self.bot.logger.debug(f"{self.name} has been drop") def add_ext_auto_defer(self, ephemeral: bool = False, time_until_defer: float = 0.0) -> None: """ @@ -305,5 +330,5 @@ def __init__(self, bot): raise TypeError("Callback must be a coroutine") if self.extension_error: - logger.warning("Extension error callback has been overridden!") + self.bot.logger.warning("Extension error callback has been overridden!") self.extension_error = coroutine diff --git a/naff/models/naff/hybrid_commands.py b/naff/models/naff/hybrid_commands.py index 3d2a35009..e2d5a7fed 100644 --- a/naff/models/naff/hybrid_commands.py +++ b/naff/models/naff/hybrid_commands.py @@ -1,15 +1,13 @@ -import inspect -import functools import asyncio - +import functools +import inspect from typing import Any, Callable, Coroutine, TYPE_CHECKING, Optional, TypeGuard +import attrs from naff.client.const import Absent, GLOBAL_SCOPE, MISSING, T from naff.client.errors import BadArgument -from naff.client.utils.attr_utils import define, field from naff.client.utils.misc_utils import get_object_name, maybe_coroutine -from naff.models.naff.command import BaseCommand from naff.models.naff.application_commands import ( SlashCommand, LocalisedName, @@ -18,8 +16,8 @@ SlashCommandChoice, OptionTypes, ) -from naff.models.naff.prefixed_commands import _convert_to_bool, PrefixedCommand -from naff.models.naff.protocols import Converter +from naff.models.naff.command import BaseCommand +from naff.models.naff.context import HybridContext, InteractionContext, PrefixedContext from naff.models.naff.converters import ( _LiteralConverter, NoArgumentConverter, @@ -28,7 +26,8 @@ RoleConverter, BaseChannelConverter, ) -from naff.models.naff.context import HybridContext, InteractionContext, PrefixedContext +from naff.models.naff.prefixed_commands import _convert_to_bool, PrefixedCommand +from naff.models.naff.protocols import Converter if TYPE_CHECKING: from naff.models.naff.checks import TYPE_CHECK_FUNCTION @@ -228,7 +227,7 @@ async def convert(self, ctx: HybridContext, _: Any) -> Any: return await maybe_coroutine(self._additional_converter_func, ctx, part_one) -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class HybridCommand(SlashCommand): """A subclass of SlashCommand that handles the logic for hybrid commands.""" @@ -281,9 +280,9 @@ def wrapper(call: Callable[..., Coroutine]) -> "HybridCommand": return wrapper -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class _HybridPrefixedCommand(PrefixedCommand): - _uses_subcommand_func: bool = field(default=False) + _uses_subcommand_func: bool = attrs.field(repr=False, default=False) async def __call__(self, context: PrefixedContext, *args, **kwargs) -> None: new_ctx = context.bot.hybrid_context.from_prefixed_context(context) @@ -353,13 +352,9 @@ def _prefixed_from_slash(cmd: SlashCommand) -> _HybridPrefixedCommand: if ori_param := old_params.pop(str(option.name), None): if ori_param.annotation != inspect._empty and _check_if_annotation(ori_param.annotation, Converter): if option.type != OptionTypes.ATTACHMENT: - annotation = _StackedConverter( - annotation, _get_converter_function(ori_param.annotation, str(option.name)) # type: ignore - ) + annotation = _StackedConverter(annotation, _get_converter_function(ori_param.annotation, str(option.name))) # type: ignore else: - annotation = _StackedNoArgConverter( - _get_converter_function(annotation, ""), _get_converter_function(ori_param.annotation, str(option.name)) # type: ignore - ) + annotation = _StackedNoArgConverter(_get_converter_function(annotation, ""), _get_converter_function(ori_param.annotation, str(option.name))) # type: ignore if not option.required and ori_param.default == inspect._empty: # prefixed commands would automatically fill this in, slash commands don't @@ -394,9 +389,9 @@ def _prefixed_from_slash(cmd: SlashCommand) -> _HybridPrefixedCommand: prefixed_cmd = _HybridPrefixedCommand( name=str(cmd.sub_cmd_name) if cmd.is_subcommand else str(cmd.name), - aliases=list((cmd.sub_cmd_name.to_locale_dict() or {}).values()) + aliases=list(cmd.sub_cmd_name.to_locale_dict().values()) if cmd.is_subcommand - else list((cmd.name.to_locale_dict() or {}).values()), + else list(cmd.name.to_locale_dict().values()), help=str(cmd.sub_cmd_description) if cmd.is_subcommand else str(cmd.description), callback=cmd.callback, extension=cmd.extension, diff --git a/naff/models/naff/listener.py b/naff/models/naff/listener.py index 753f00192..a26950339 100644 --- a/naff/models/naff/listener.py +++ b/naff/models/naff/listener.py @@ -2,10 +2,10 @@ import inspect from typing import Coroutine, Callable -from naff.models.naff.callback import CallbackObject from naff.api.events.internal import BaseEvent from naff.client.const import MISSING, Absent from naff.client.utils import get_event_name +from naff.models.naff.callback import CallbackObject __all__ = ("Listener", "listen") @@ -16,20 +16,51 @@ class Listener(CallbackObject): """Name of the event to listen to.""" callback: Coroutine """Coroutine to call when the event is triggered.""" - - def __init__(self, func: Callable[..., Coroutine], event: str) -> None: + is_default_listener: bool + """Whether this listener is provided automatically by the library, and might be unwanted by users.""" + disable_default_listeners: bool + """Whether this listener supersedes default listeners. If true, any default listeners will be unregistered.""" + delay_until_ready: bool + """whether to delay the event until the client is ready""" + + def __init__( + self, + func: Callable[..., Coroutine], + event: str, + *, + delay_until_ready: bool = False, + is_default_listener: bool = False, + disable_default_listeners: bool = False, + ) -> None: super().__init__() + if is_default_listener: + disable_default_listeners = False + self.event = event self.callback = func + self.delay_until_ready = delay_until_ready + self.is_default_listener = is_default_listener + self.disable_default_listeners = disable_default_listeners @classmethod - def create(cls, event_name: Absent[str | BaseEvent] = MISSING) -> Callable[[Coroutine], "Listener"]: + def create( + cls, + event_name: Absent[str | BaseEvent] = MISSING, + *, + delay_until_ready: bool = False, + is_default_listener: bool = False, + disable_default_listeners: bool = False, + ) -> Callable[[Coroutine], "Listener"]: """ Decorator for creating an event listener. Args: event_name: The name of the event to listen to. If left blank, event name will be inferred from the function name or parameter. + delay_until_ready: Whether to delay the listener until the client is ready. + is_default_listener: Whether this listener is provided automatically by the library, and might be unwanted by users. + disable_default_listeners: Whether this listener supersedes default listeners. If true, any default listeners will be unregistered. + Returns: A listener object. @@ -55,20 +86,41 @@ def wrapper(coro: Coroutine) -> "Listener": if not name: name = coro.__name__ - return cls(coro, get_event_name(name)) + return cls( + coro, + get_event_name(name), + delay_until_ready=delay_until_ready, + is_default_listener=is_default_listener, + disable_default_listeners=disable_default_listeners, + ) return wrapper -def listen(event_name: Absent[str | BaseEvent] = MISSING) -> Callable[[Callable[..., Coroutine]], Listener]: +def listen( + event_name: Absent[str | BaseEvent] = MISSING, + *, + delay_until_ready: bool = True, + is_default_listener: bool = False, + disable_default_listeners: bool = False, +) -> Callable[[Callable[..., Coroutine]], Listener]: """ Decorator to make a function an event listener. Args: event_name: The name of the event to listen to. If left blank, event name will be inferred from the function name or parameter. + delay_until_ready: Whether to delay the listener until the client is ready. + is_default_listener: Whether this listener is provided automatically by the library, and might be unwanted by users. + disable_default_listeners: Whether this listener supersedes default listeners. If true, any default listeners will be unregistered. + Returns: A listener object. """ - return Listener.create(event_name) + return Listener.create( + event_name, + delay_until_ready=delay_until_ready, + is_default_listener=is_default_listener, + disable_default_listeners=disable_default_listeners, + ) diff --git a/naff/models/naff/localisation.py b/naff/models/naff/localisation.py index f6d0e7d98..fde048160 100644 --- a/naff/models/naff/localisation.py +++ b/naff/models/naff/localisation.py @@ -1,12 +1,13 @@ from functools import cached_property +import attrs + from naff.client import const -from naff.client.utils import define, field __all__ = ("LocalisedField", "LocalizedField") -@define(slots=False) +@attrs.define(eq=False, order=False, hash=False, slots=False) class LocalisedField: """ An object that enables support for localising fields. @@ -14,38 +15,38 @@ class LocalisedField: Supported locales: https://discord.com/developers/docs/reference#locales """ - default_locale: str = field(default=const.default_locale) - - bulgarian: str | None = field(default=None, metadata={"locale-code": "bg"}) - chinese_china: str | None = field(default=None, metadata={"locale-code": "zh-CN"}) - chinese_taiwan: str | None = field(default=None, metadata={"locale-code": "zh-TW"}) - croatian: str | None = field(default=None, metadata={"locale-code": "hr"}) - czech: str | None = field(default=None, metadata={"locale-code": "cs"}) - danish: str | None = field(default=None, metadata={"locale-code": "da"}) - dutch: str | None = field(default=None, metadata={"locale-code": "nl"}) - english_uk: str | None = field(default=None, metadata={"locale-code": "en-GB"}) - english_us: str | None = field(default=None, metadata={"locale-code": "en-US"}) - finnish: str | None = field(default=None, metadata={"locale-code": "fi"}) - french: str | None = field(default=None, metadata={"locale-code": "fr"}) - german: str | None = field(default=None, metadata={"locale-code": "de"}) - greek: str | None = field(default=None, metadata={"locale-code": "el"}) - hindi: str | None = field(default=None, metadata={"locale-code": "hi"}) - hungarian: str | None = field(default=None, metadata={"locale-code": "hu"}) - italian: str | None = field(default=None, metadata={"locale-code": "it"}) - japanese: str | None = field(default=None, metadata={"locale-code": "ja"}) - korean: str | None = field(default=None, metadata={"locale-code": "ko"}) - lithuanian: str | None = field(default=None, metadata={"locale-code": "lt"}) - norwegian: str | None = field(default=None, metadata={"locale-code": "no"}) - polish: str | None = field(default=None, metadata={"locale-code": "pl"}) - portuguese_brazilian: str | None = field(default=None, metadata={"locale-code": "pt-BR"}) - romanian_romania: str | None = field(default=None, metadata={"locale-code": "ro"}) - russian: str | None = field(default=None, metadata={"locale-code": "ru"}) - spanish: str | None = field(default=None, metadata={"locale-code": "es-ES"}) - swedish: str | None = field(default=None, metadata={"locale-code": "sv-SE"}) - thai: str | None = field(default=None, metadata={"locale-code": "th"}) - turkish: str | None = field(default=None, metadata={"locale-code": "tr"}) - ukrainian: str | None = field(default=None, metadata={"locale-code": "uk"}) - vietnamese: str | None = field(default=None, metadata={"locale-code": "vi"}) + default_locale: str = attrs.field(repr=False, default=const.default_locale) + + bulgarian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "bg"}) + chinese_china: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "zh-CN"}) + chinese_taiwan: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "zh-TW"}) + croatian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "hr"}) + czech: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "cs"}) + danish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "da"}) + dutch: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "nl"}) + english_uk: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "en-GB"}) + english_us: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "en-US"}) + finnish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "fi"}) + french: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "fr"}) + german: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "de"}) + greek: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "el"}) + hindi: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "hi"}) + hungarian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "hu"}) + italian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "it"}) + japanese: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "ja"}) + korean: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "ko"}) + lithuanian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "lt"}) + norwegian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "no"}) + polish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "pl"}) + portuguese_brazilian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "pt-BR"}) + romanian_romania: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "ro"}) + russian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "ru"}) + spanish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "es-ES"}) + swedish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "sv-SE"}) + thai: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "th"}) + turkish: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "tr"}) + ukrainian: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "uk"}) + vietnamese: str | None = attrs.field(repr=False, default=None, metadata={"locale-code": "vi"}) def __str__(self) -> str: return str(self.default) diff --git a/naff/models/naff/prefixed_commands.py b/naff/models/naff/prefixed_commands.py index 5a3f72a17..d97f34cf7 100644 --- a/naff/models/naff/prefixed_commands.py +++ b/naff/models/naff/prefixed_commands.py @@ -9,12 +9,12 @@ from naff.client.const import MISSING from naff.client.errors import BadArgument +from naff.client.utils.attr_utils import docs from naff.client.utils.input_utils import _quotes -from naff.client.utils.attr_utils import define, field, docs from naff.client.utils.misc_utils import get_object_name, maybe_coroutine -from naff.models.naff.protocols import Converter -from naff.models.naff.converters import _LiteralConverter, NoArgumentConverter, Greedy, NAFF_MODEL_TO_CONVERTER from naff.models.naff.command import BaseCommand +from naff.models.naff.converters import _LiteralConverter, NoArgumentConverter, Greedy, NAFF_MODEL_TO_CONVERTER +from naff.models.naff.protocols import Converter if TYPE_CHECKING: from naff.models.naff.context import PrefixedContext @@ -28,7 +28,7 @@ _STARTING_QUOTES = frozenset(_quotes.keys()) -@attrs.define(slots=True) +@attrs.define(eq=False, order=False, hash=False, slots=True) class PrefixedCommandParameter: """ An object representing parameters in a prefixed command. @@ -62,7 +62,7 @@ def optional(self) -> bool: return self.default != MISSING -@attrs.define(slots=True) +@attrs.define(eq=False, order=False, hash=False, slots=True) class _PrefixedArgsIterator: """ An iterator over the arguments of a prefixed command. @@ -257,40 +257,44 @@ async def _greedy_convert( return greedy_args, broke_off -@define() +@attrs.define(eq=False, order=False, hash=False, kw_only=True) class PrefixedCommand(BaseCommand): - name: str = field(metadata=docs("The name of the command.")) - parameters: list[PrefixedCommandParameter] = field(metadata=docs("The parameters of the command."), factory=list) - aliases: list[str] = field( + name: str = attrs.field(repr=False, metadata=docs("The name of the command.")) + parameters: list[PrefixedCommandParameter] = attrs.field( + repr=False, metadata=docs("The parameters of the command."), factory=list + ) + aliases: list[str] = attrs.field( metadata=docs("The list of aliases the command can be invoked under."), factory=list, ) - hidden: bool = field( + hidden: bool = attrs.field( metadata=docs("If `True`, help commands should not show this in the help output (unless toggled to do so)."), default=False, ) - ignore_extra: bool = field( + ignore_extra: bool = attrs.field( metadata=docs( "If `True`, ignores extraneous strings passed to a command if all its requirements are met (e.g. ?foo a b c" " when only expecting a and b). Otherwise, an error is raised. Defaults to True." ), default=True, ) - hierarchical_checking: bool = field( + hierarchical_checking: bool = attrs.field( metadata=docs( "If `True` and if the base of a subcommand, every subcommand underneath it will run this command's checks" " and cooldowns before its own. Otherwise, only the subcommand's checks are checked." ), default=True, ) - help: Optional[str] = field(metadata=docs("The long help text for the command."), default=None) - brief: Optional[str] = field(metadata=docs("The short help text for the command."), default=None) - parent: Optional["PrefixedCommand"] = field(metadata=docs("The parent command, if applicable."), default=None) - subcommands: dict[str, "PrefixedCommand"] = field( - metadata=docs("A dict of all subcommands for the command."), factory=dict + help: Optional[str] = attrs.field(repr=False, metadata=docs("The long help text for the command."), default=None) + brief: Optional[str] = attrs.field(repr=False, metadata=docs("The short help text for the command."), default=None) + parent: Optional["PrefixedCommand"] = attrs.field( + repr=False, metadata=docs("The parent command, if applicable."), default=None + ) + subcommands: dict[str, "PrefixedCommand"] = attrs.field( + repr=False, metadata=docs("A dict of all subcommands for the command."), factory=dict ) - _usage: Optional[str] = field(default=None) - _inspect_signature: Optional[inspect.Signature] = field(default=None) + _usage: Optional[str] = attrs.field(repr=False, default=None) + _inspect_signature: Optional[inspect.Signature] = attrs.field(repr=False, default=None) def __attrs_post_init__(self) -> None: super().__attrs_post_init__() # we want checks to work diff --git a/naff/models/naff/protocols.py b/naff/models/naff/protocols.py index 639d520b0..cee6c632f 100644 --- a/naff/models/naff/protocols.py +++ b/naff/models/naff/protocols.py @@ -1,11 +1,10 @@ import typing from typing import Protocol, Any, TYPE_CHECKING -from naff.client.const import T_co from naff.api.http.route import Route +from naff.client.const import T_co from naff.models.discord.file import UPLOADABLE_TYPE - if TYPE_CHECKING: from naff.models.naff.context import Context diff --git a/naff/models/naff/tasks/task.py b/naff/models/naff/tasks/task.py index de9ca4f5a..1f745d883 100644 --- a/naff/models/naff/tasks/task.py +++ b/naff/models/naff/tasks/task.py @@ -5,11 +5,9 @@ from typing import Callable import naff -from naff.client.const import logger - +from naff.client.const import get_logger from .triggers import BaseTrigger - __all__ = ("Task",) @@ -120,7 +118,7 @@ def start(self) -> None: self._stop.clear() self.task = asyncio.create_task(self._task_loop()) except RuntimeError: - logger.error( + get_logger().error( "Unable to start task without a running event loop! We recommend starting tasks within an `on_startup` event." ) diff --git a/poetry.lock b/poetry.lock index 7f93f408b..8515339e9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -16,7 +16,7 @@ multidict = ">=4.5,<7.0" yarl = ">=1.0,<2.0" [package.extras] -speedups = ["cchardet", "brotli", "aiodns"] +speedups = ["Brotli", "aiodns", "cchardet"] [[package]] name = "aiosignal" @@ -29,6 +29,14 @@ python-versions = ">=3.6" [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "ansicon" +version = "1.89.0" +description = "Python wrapper for loading Jason Hood's ANSICON" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "async-timeout" version = "4.0.2" @@ -46,10 +54,10 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [package.extras] -tests_no_zope = ["cloudpickle", "pytest-mypy-plugins", "mypy", "six", "pytest (>=4.3.0)", "pympler", "hypothesis", "coverage[toml] (>=5.0.2)"] -tests = ["cloudpickle", "zope.interface", "pytest-mypy-plugins", "mypy", "six", "pytest (>=4.3.0)", "pympler", "hypothesis", "coverage[toml] (>=5.0.2)"] -docs = ["sphinx-notfound-page", "zope.interface", "sphinx", "furo"] -dev = ["cloudpickle", "pre-commit", "sphinx-notfound-page", "sphinx", "furo", "zope.interface", "pytest-mypy-plugins", "mypy", "six", "pytest (>=4.3.0)", "pympler", "hypothesis", "coverage[toml] (>=5.0.2)"] +dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "sphinx", "sphinx-notfound-page", "zope.interface"] +docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] +tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "zope.interface"] +tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six"] [[package]] name = "black" @@ -67,10 +75,23 @@ platformdirs = ">=2" tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} [package.extras] -uvloop = ["uvloop (>=0.15.2)"] -jupyter = ["tokenize-rt (>=3.2.0)", "ipython (>=7.8.0)"] -d = ["aiohttp (>=3.7.4)"] colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + +[[package]] +name = "blessed" +version = "1.19.1" +description = "Easy, practical library for making terminal apps, by providing an elegant, well-documented interface to Colors, Keyboard input, and screen Positioning capabilities." +category = "main" +optional = false +python-versions = ">=2.7" + +[package.dependencies] +jinxed = {version = ">=1.1.0", markers = "platform_system == \"Windows\""} +six = ">=1.9.0" +wcwidth = ">=0.1.4" [[package]] name = "cfgv" @@ -102,6 +123,14 @@ python-versions = ">=3.7" [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "codefind" +version = "0.1.3" +description = "Find code objects and their referents" +category = "main" +optional = false +python-versions = ">=3.8,<4.0" + [[package]] name = "colorama" version = "0.4.5" @@ -138,8 +167,8 @@ optional = false python-versions = ">=3.7" [package.extras] -testing = ["pytest-timeout (>=2.1)", "pytest-cov (>=3)", "pytest (>=7.1.2)", "coverage (>=6.4.2)", "covdefaults (>=2.2)"] -docs = ["sphinx-autodoc-typehints (>=1.19.1)", "sphinx (>=5.1.1)", "furo (>=2022.6.21)"] +docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"] +testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"] [[package]] name = "frozenlist" @@ -168,6 +197,34 @@ category = "main" optional = false python-versions = ">=3.5" +[[package]] +name = "jinxed" +version = "1.2.0" +description = "Jinxed Terminal Library" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +ansicon = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "jurigged" +version = "0.5.3" +description = "Live update of Python functions" +category = "main" +optional = false +python-versions = ">=3.8,<4.0" + +[package.dependencies] +blessed = ">=1.17.12,<2.0.0" +codefind = ">=0.1.3,<0.2.0" +ovld = ">=0.3.1,<0.4.0" +watchdog = ">=1.0.2" + +[package.extras] +develoop = ["giving (>=0.3.6,<0.4.0)", "hrepr (>=0.4.0,<0.5.0)", "rich (>=10.13.0,<11.0.0)"] + [[package]] name = "multidict" version = "6.0.2" @@ -190,9 +247,9 @@ tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = ">=3.10" [package.extras] -reports = ["lxml"] -python2 = ["typed-ast (>=1.4.0,<2)"] dmypy = ["psutil (>=4.0)"] +python2 = ["typed-ast (>=1.4.0,<2)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" @@ -210,6 +267,9 @@ category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" +[package.dependencies] +setuptools = "*" + [[package]] name = "orjson" version = "3.7.12" @@ -218,6 +278,14 @@ category = "main" optional = true python-versions = ">=3.7" +[[package]] +name = "ovld" +version = "0.3.2" +description = "Overloading Python functions" +category = "main" +optional = false +python-versions = ">=3.6,<4.0" + [[package]] name = "pathspec" version = "0.9.0" @@ -235,8 +303,8 @@ optional = false python-versions = ">=3.7" [package.extras] -test = ["pytest (>=6)", "pytest-mock (>=3.6)", "pytest-cov (>=2.7)", "appdirs (==1.4.4)"] -docs = ["sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)", "proselint (>=0.10.2)", "furo (>=2021.7.5b38)"] +docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"] +test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] [[package]] name = "pre-commit" @@ -262,6 +330,27 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "setuptools" +version = "65.3.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" + [[package]] name = "toml" version = "0.10.2" @@ -300,8 +389,27 @@ filelock = ">=3.4.1,<4" platformdirs = ">=2.4,<3" [package.extras] -testing = ["pytest-timeout (>=2.1)", "pytest-randomly (>=3.10.3)", "pytest-mock (>=3.6.1)", "pytest-freezegun (>=0.4.2)", "pytest-env (>=0.6.2)", "pytest (>=7.0.1)", "packaging (>=21.3)", "flaky (>=3.7)", "coverage-enable-subprocess (>=1)", "coverage (>=6.2)"] -docs = ["towncrier (>=21.9)", "sphinx-rtd-theme (>=1)", "sphinx-argparse (>=0.3.1)", "sphinx (>=5.1.1)", "proselint (>=0.13)"] +docs = ["proselint (>=0.13)", "sphinx (>=5.1.1)", "sphinx-argparse (>=0.3.1)", "sphinx-rtd-theme (>=1)", "towncrier (>=21.9)"] +testing = ["coverage (>=6.2)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=21.3)", "pytest (>=7.0.1)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.2)", "pytest-mock (>=3.6.1)", "pytest-randomly (>=3.10.3)", "pytest-timeout (>=2.1)"] + +[[package]] +name = "watchdog" +version = "2.1.9" +description = "Filesystem events monitoring" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +watchmedo = ["PyYAML (>=3.10)"] + +[[package]] +name = "wcwidth" +version = "0.2.5" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" +optional = false +python-versions = "*" [[package]] name = "yarl" @@ -321,7 +429,7 @@ orjson = ["orjson"] [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "05642a4dd000c0ddab4ec77d8c5c941f305a5c0c6fd23f7887a2e19bfeb61af6" +content-hash = "3014d49f86b05fbc78d76d11b8e42cd5f779b7e2ffe6ded9638681fcc7d4fb1b" [metadata.files] aiohttp = [ @@ -402,6 +510,10 @@ aiosignal = [ {file = "aiosignal-1.2.0-py3-none-any.whl", hash = "sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a"}, {file = "aiosignal-1.2.0.tar.gz", hash = "sha256:78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2"}, ] +ansicon = [ + {file = "ansicon-1.89.0-py2.py3-none-any.whl", hash = "sha256:f1def52d17f65c2c9682cf8370c03f541f410c1752d6a14029f97318e4b9dfec"}, + {file = "ansicon-1.89.0.tar.gz", hash = "sha256:e4d039def5768a47e4afec8e89e83ec3ae5a26bf00ad851f914d1240b444d2b1"}, +] async-timeout = [ {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, @@ -435,6 +547,10 @@ black = [ {file = "black-22.6.0-py3-none-any.whl", hash = "sha256:ac609cf8ef5e7115ddd07d85d988d074ed00e10fbc3445aee393e70164a2219c"}, {file = "black-22.6.0.tar.gz", hash = "sha256:6c6d39e28aed379aec40da1c65434c77d75e65bb59a1e1c283de545fb4e7c6c9"}, ] +blessed = [ + {file = "blessed-1.19.1-py2.py3-none-any.whl", hash = "sha256:63b8554ae2e0e7f43749b6715c734cc8f3883010a809bf16790102563e6cf25b"}, + {file = "blessed-1.19.1.tar.gz", hash = "sha256:9a0d099695bf621d4680dd6c73f6ad547f6a3442fbdbe80c4b1daa1edbc492fc"}, +] cfgv = [ {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"}, {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"}, @@ -447,6 +563,10 @@ click = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] +codefind = [ + {file = "codefind-0.1.3-py3-none-any.whl", hash = "sha256:3ffe85b74595b5c9f82391a11171ce7d68f1f555485720ab922f3b86f9bf30ec"}, + {file = "codefind-0.1.3.tar.gz", hash = "sha256:5667050361bf601a253031b2437d16b7d82cb0fa0e756d93e548c7b35ce6f910"}, +] colorama = [ {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, @@ -532,6 +652,14 @@ idna = [ {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, ] +jinxed = [ + {file = "jinxed-1.2.0-py2.py3-none-any.whl", hash = "sha256:cfc2b2e4e3b4326954d546ba6d6b9a7a796ddcb0aef8d03161d005177eb0d48b"}, + {file = "jinxed-1.2.0.tar.gz", hash = "sha256:032acda92d5c57cd216033cbbd53de731e6ed50deb63eb4781336ca55f72cda5"}, +] +jurigged = [ + {file = "jurigged-0.5.3-py3-none-any.whl", hash = "sha256:355a9bddf42cae541e862796fb125827fc35573a982c6f35d3dc5621e59c91e3"}, + {file = "jurigged-0.5.3.tar.gz", hash = "sha256:47cf4e9f10455a39602caa447888c06adda962699c65f19d8c37509817341b5e"}, +] multidict = [ {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"}, {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"}, @@ -670,6 +798,10 @@ orjson = [ {file = "orjson-3.7.12-cp39-none-win_amd64.whl", hash = "sha256:c1e4297b5dee3e14e068cc35505b3e1a626dd3fb8d357842902616564d2f713f"}, {file = "orjson-3.7.12.tar.gz", hash = "sha256:05f20fa1a368207d16ecdf16072c3be58f85c4954cd2ed6c9704463963b9791a"}, ] +ovld = [ + {file = "ovld-0.3.2-py3-none-any.whl", hash = "sha256:3a5f08f66573198b490fc69dcf93a2ad9b4d90fd1fef885cf7a8dbe565f17837"}, + {file = "ovld-0.3.2.tar.gz", hash = "sha256:f8918636c240a2935175406801944d4314823710b3afbd5a8db3e79cd9391c42"}, +] pathspec = [ {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, @@ -690,6 +822,13 @@ pyyaml = [ {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, @@ -717,6 +856,14 @@ pyyaml = [ {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] +setuptools = [ + {file = "setuptools-65.3.0-py3-none-any.whl", hash = "sha256:2e24e0bec025f035a2e72cdd1961119f557d78ad331bb00ff82efb2ab8da8e82"}, + {file = "setuptools-65.3.0.tar.gz", hash = "sha256:7732871f4f7fa58fb6bdcaeadb0161b2bd046c85905dbaa066bdcbcc81953b57"}, +] +six = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, @@ -733,6 +880,37 @@ virtualenv = [ {file = "virtualenv-20.16.3-py2.py3-none-any.whl", hash = "sha256:4193b7bc8a6cd23e4eb251ac64f29b4398ab2c233531e66e40b19a6b7b0d30c1"}, {file = "virtualenv-20.16.3.tar.gz", hash = "sha256:d86ea0bb50e06252d79e6c241507cb904fcd66090c3271381372d6221a3970f9"}, ] +watchdog = [ + {file = "watchdog-2.1.9-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a735a990a1095f75ca4f36ea2ef2752c99e6ee997c46b0de507ba40a09bf7330"}, + {file = "watchdog-2.1.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b17d302850c8d412784d9246cfe8d7e3af6bcd45f958abb2d08a6f8bedf695d"}, + {file = "watchdog-2.1.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ee3e38a6cc050a8830089f79cbec8a3878ec2fe5160cdb2dc8ccb6def8552658"}, + {file = "watchdog-2.1.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:64a27aed691408a6abd83394b38503e8176f69031ca25d64131d8d640a307591"}, + {file = "watchdog-2.1.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:195fc70c6e41237362ba720e9aaf394f8178bfc7fa68207f112d108edef1af33"}, + {file = "watchdog-2.1.9-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:bfc4d351e6348d6ec51df007432e6fe80adb53fd41183716017026af03427846"}, + {file = "watchdog-2.1.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8250546a98388cbc00c3ee3cc5cf96799b5a595270dfcfa855491a64b86ef8c3"}, + {file = "watchdog-2.1.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:117ffc6ec261639a0209a3252546b12800670d4bf5f84fbd355957a0595fe654"}, + {file = "watchdog-2.1.9-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:97f9752208f5154e9e7b76acc8c4f5a58801b338de2af14e7e181ee3b28a5d39"}, + {file = "watchdog-2.1.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:247dcf1df956daa24828bfea5a138d0e7a7c98b1a47cf1fa5b0c3c16241fcbb7"}, + {file = "watchdog-2.1.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:226b3c6c468ce72051a4c15a4cc2ef317c32590d82ba0b330403cafd98a62cfd"}, + {file = "watchdog-2.1.9-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d9820fe47c20c13e3c9dd544d3706a2a26c02b2b43c993b62fcd8011bcc0adb3"}, + {file = "watchdog-2.1.9-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:70af927aa1613ded6a68089a9262a009fbdf819f46d09c1a908d4b36e1ba2b2d"}, + {file = "watchdog-2.1.9-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ed80a1628cee19f5cfc6bb74e173f1b4189eb532e705e2a13e3250312a62e0c9"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9f05a5f7c12452f6a27203f76779ae3f46fa30f1dd833037ea8cbc2887c60213"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_armv7l.whl", hash = "sha256:255bb5758f7e89b1a13c05a5bceccec2219f8995a3a4c4d6968fe1de6a3b2892"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_i686.whl", hash = "sha256:d3dda00aca282b26194bdd0adec21e4c21e916956d972369359ba63ade616153"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_ppc64.whl", hash = "sha256:186f6c55abc5e03872ae14c2f294a153ec7292f807af99f57611acc8caa75306"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:083171652584e1b8829581f965b9b7723ca5f9a2cd7e20271edf264cfd7c1412"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_s390x.whl", hash = "sha256:b530ae007a5f5d50b7fbba96634c7ee21abec70dc3e7f0233339c81943848dc1"}, + {file = "watchdog-2.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:4f4e1c4aa54fb86316a62a87b3378c025e228178d55481d30d857c6c438897d6"}, + {file = "watchdog-2.1.9-py3-none-win32.whl", hash = "sha256:5952135968519e2447a01875a6f5fc8c03190b24d14ee52b0f4b1682259520b1"}, + {file = "watchdog-2.1.9-py3-none-win_amd64.whl", hash = "sha256:7a833211f49143c3d336729b0020ffd1274078e94b0ae42e22f596999f50279c"}, + {file = "watchdog-2.1.9-py3-none-win_ia64.whl", hash = "sha256:ad576a565260d8f99d97f2e64b0f97a48228317095908568a9d5c786c829d428"}, + {file = "watchdog-2.1.9.tar.gz", hash = "sha256:43ce20ebb36a51f21fa376f76d1d4692452b2527ccd601950d69ed36b9e21609"}, +] +wcwidth = [ + {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, + {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, +] yarl = [ {file = "yarl-1.8.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:abc06b97407868ef38f3d172762f4069323de52f2b70d133d096a48d72215d28"}, {file = "yarl-1.8.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:07b21e274de4c637f3e3b7104694e53260b5fc10d51fb3ec5fed1da8e0f754e3"}, diff --git a/pyproject.toml b/pyproject.toml index 0f971024e..4fbdcef29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,27 +1,35 @@ [tool.poetry] name = "naff" -version = "1.12.0" +version = "2.0.0" description = "Not another freaking fork" authors = [ "LordOfPolls ", ] [tool.poetry.dependencies] -python = "^3.10" -aiohttp = "^3.7.4" +python = ">=3.10" +aiohttp = "^3.8.3" attrs = "^21.4.0" mypy = ">0.930" discord-typings = "^0.5.1" tomli = "^2.0.1" +emoji = "^2.1.0" [tool.poetry.dependencies.orjson] version = "^3.6.8" optional = true +[tool.poetry.dependencies.jurigged] +version = "^0.5.3" +optional = true + [tool.poetry.extras] orjson = [ "orjson", ] +jurigged = [ + "jurigged", +] [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/pytest.ini b/pytest.ini index a96bd98ca..7efc3d959 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,7 @@ addopts = -l -ra --durations=2 --cov=./ --cov-report xml:coverage.xml --junitxml=TestResults.xml doctest_optionflags = NORMALIZE_WHITESPACE asyncio_mode=auto +log_cli = 1 +log_cli_level = DEBUG +log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) +log_cli_date_format=%Y-%m-%d %H:%M:%S diff --git a/requirements.txt b/requirements.txt index f643bed28..fedbc3642 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ aiohttp attrs discord-typings>=0.5.1 +emoji tomli diff --git a/setup.py b/setup.py index 84915242a..50bd22636 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ "voice": ["PyNaCl>=1.5.0,<1.6"], "speedup": ["aiodns", "orjson", "Brotli"], "sentry": ["sentry-sdk"], + "jurigged": ["jurigged"], } extras_require["all"] = list(itertools.chain.from_iterable(extras_require.values())) extras_require["docs"] = extras_require["all"] + [ diff --git a/tests/test_bot.py b/tests/test_bot.py index c6650659e..fd398815c 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from asyncio import AbstractEventLoop from contextlib import suppress @@ -37,8 +38,6 @@ from naff.api.http.route import Route from naff.api.voice.audio import AudioVolume from naff.client.errors import NotFound -from naff.client.utils.misc_utils import find -from naff.models.discord.timestamp import Timestamp __all__ = () @@ -54,6 +53,10 @@ TOKEN = os.environ.get("BOT_TOKEN") if not TOKEN: pytest.skip(f"Skipping {os.path.basename(__file__)} - no token provided", allow_module_level=True) +if os.environ.get("GITHUB_ACTIONS") and not os.environ.get("RUN_TESTBOT"): + pytest.skip(f"Skipping {os.path.basename(__file__)} - RUN_TESTBOT not set", allow_module_level=True) + +log = logging.getLogger("NAFF-Integration-Tests") @pytest.fixture(scope="module") @@ -62,46 +65,51 @@ def event_loop() -> AbstractEventLoop: @pytest.fixture(scope="module") -async def bot() -> Client: +async def bot(github_commit) -> Client: bot = naff.Client(activity="Testing someones code") await bot.login(TOKEN) asyncio.create_task(bot.start_gateway()) await bot._ready.wait() + bot.suffix = github_commit + log.info(f"Logged in as {bot.user} ({bot.user.id}) -- {bot.suffix}") yield bot @pytest.fixture(scope="module") async def guild(bot: Client) -> Guild: - if len(bot.guilds) > 9: - leftover = find(lambda g: g.is_owner(bot.user.id) and g.name == "test_suite_guild", bot.guilds) - if leftover: - age = Timestamp.now() - leftover.created_at - if age.days > 0: - # This from a failed run, let's clean it up - await leftover.delete() - guild: naff.Guild = await naff.Guild.create("test_suite_guild", bot) - community_channel = await guild.create_text_channel("community_channel") - - await guild.edit( - features=["COMMUNITY"], - rules_channel=community_channel, - system_channel=community_channel, - public_updates_channel=community_channel, - explicit_content_filter=ExplicitContentFilterLevels.ALL_MEMBERS, - verification_level=VerificationLevels.LOW, - ) + guild = next((g for g in bot.guilds if g.name == "NAFF Test Suite"), None) + if not guild: + log.info("No guild found, creating one...") + + guild: naff.Guild = await naff.Guild.create("NAFF Test Suite", bot) + community_channel = await guild.create_text_channel("community_channel") + + await guild.edit( + features=["COMMUNITY"], + rules_channel=community_channel, + system_channel=community_channel, + public_updates_channel=community_channel, + explicit_content_filter=ExplicitContentFilterLevels.ALL_MEMBERS, + verification_level=VerificationLevels.LOW, + ) yield guild - await guild.delete() - @pytest.fixture(scope="module") async def channel(bot, guild) -> GuildText: - channel = await guild.create_text_channel("test_scene") - return channel + channel = await guild.create_text_channel(f"test_scene - {bot.suffix}") + yield channel + await channel.delete() + + +@pytest.fixture(scope="module") +async def github_commit() -> str: + import subprocess # noqa: S404 + + return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() # noqa: S603, S607 def ensure_attributes(target_object) -> None: @@ -114,10 +122,10 @@ def ensure_attributes(target_object) -> None: async def test_channels(bot: Client, guild: Guild) -> None: channels = [ guild_category := await guild.create_category("_test_category"), - await guild.create_text_channel("_test_text"), - await guild.create_news_channel("_test_news"), - await guild.create_stage_channel("_test_stage"), - await guild.create_voice_channel("_test_voice"), + await guild.create_text_channel(f"_test_text-{bot.suffix}"), + await guild.create_news_channel(f"_test_news-{bot.suffix}"), + await guild.create_stage_channel(f"_test_stage-{bot.suffix}"), + await guild.create_voice_channel(f"_test_voice-{bot.suffix}"), ] assert all(c in guild.channels for c in channels) @@ -133,7 +141,6 @@ async def test_channels(bot: Client, guild: Guild) -> None: assert channel.category == guild_category if isinstance(channel, MessageableMixin) and not isinstance(channel, GuildVoice): - # todo: remove the guild voice exception when text-in-voice releases _m = await channel.send("test") assert _m.channel == channel @@ -179,7 +186,7 @@ async def test_messages(bot: Client, guild: Guild, channel: GuildText) -> None: _m = await thread.send("Test") ensure_attributes(_m) - await _m.edit("Test Edit") + await _m.edit(content="Test Edit") assert _m.content == "Test Edit" await _m.add_reaction("❌") with suppress(asyncio.exceptions.TimeoutError): @@ -253,7 +260,7 @@ async def test_roles(bot: Client, guild: Guild) -> None: await guild.me.add_role(roles[0]) await guild.me.remove_role(roles[0]) - await roles[0].edit("_test_renamed", color=BrandColors.RED) + await roles[0].edit(name="_test_renamed", color=BrandColors.RED) for role in roles: await role.delete() @@ -280,20 +287,29 @@ async def test_members(bot: Client, guild: Guild, channel: GuildText) -> None: await bot.wait_for("member_update", timeout=2) assert member.display_name == (bot.get_user(member.id)).username - assert len(member.roles) == 0 - role = await guild.create_role("test") - await member.add_role(role) - with suppress(asyncio.exceptions.TimeoutError): - await bot.wait_for("member_update", timeout=2) - assert len(member.roles) != 0 + base_line = len(member.roles) - assert member.display_avatar is not None - assert member.display_name is not None + assert len(member.roles) == base_line + role = await guild.create_role(f"test-{bot.suffix}") + try: + await member.add_role(role) + with suppress(asyncio.exceptions.TimeoutError): + await bot.wait_for("member_update", timeout=2) + assert len(member.roles) != base_line + await member.remove_role(role) + with suppress(asyncio.exceptions.TimeoutError): + await bot.wait_for("member_update", timeout=2) + assert len(member.roles) == base_line + + assert member.display_avatar is not None + assert member.display_name is not None - assert member.has_permission(Permissions.SEND_MESSAGES) - assert member.channel_permissions(channel) + assert member.has_permission(Permissions.SEND_MESSAGES) + assert member.channel_permissions(channel) - assert member.guild_permissions is not None + assert member.guild_permissions is not None + finally: + await role.delete() @pytest.mark.asyncio @@ -399,8 +415,8 @@ async def test_components(bot: Client, channel: GuildText) -> None: components=naff.ActionRow(*[naff.Button(1, "test"), naff.Button(1, "test")]), ) await thread.send( - "Test - Select", - components=naff.Select([SelectOption("test", "test")]), + "Test - StringSelectMenu", + components=naff.StringSelectMenu([SelectOption("test", "test")]), ) Modal("Test Modal", [ParagraphText("test", value="test value, press send")]) @@ -413,77 +429,76 @@ async def test_components(bot: Client, channel: GuildText) -> None: @pytest.mark.asyncio -async def test_webhooks(bot: Client, guild: Guild) -> None: - test_channel = await guild.create_text_channel("_test_webhooks") - test_thread = await test_channel.create_thread("Test Thread") +async def test_webhooks(bot: Client, guild: Guild, channel: GuildText) -> None: + test_thread = await channel.create_thread("Test Thread") - try: - hook = await test_channel.create_webhook("Test") - await hook.send("Test 123") - await hook.delete() - - hook = await test_channel.create_webhook("Test-Avatar", r"tests/LordOfPolls.png") - - _m = await hook.send("Test", wait=True) - assert isinstance(_m, Message) - assert _m.webhook_id == hook.id - await hook.send("Test", username="Different Name", wait=True) - await hook.send("Test", avatar_url=bot.user.avatar.url, wait=True) - _m = await hook.send("Test", thread=test_thread, wait=True) - assert _m is not None - assert _m.channel == test_thread - - await hook.delete() - finally: - await test_channel.delete() + hook = await channel.create_webhook("Test") + await hook.send("Test 123") + await hook.delete() + hook = await channel.create_webhook("Test-Avatar", r"tests/LordOfPolls.png") -@pytest.mark.asyncio -async def test_voice(bot: Client, guild: Guild) -> None: - try: - import nacl # noqa - except ImportError: - # testing on a non-voice extra - return - test_channel = await guild.create_voice_channel("_test_voice") - test_channel_two = await guild.create_voice_channel("_test_voice_two") + _m = await hook.send("Test", wait=True) + assert isinstance(_m, Message) + assert _m.webhook_id == hook.id + await hook.send("Test", username="Different Name", wait=True) + await hook.send("Test", avatar_url=bot.user.avatar.url, wait=True) + _m = await hook.send("Test", thread=test_thread, wait=True) + assert _m is not None + assert _m.channel == test_thread - vc = await test_channel.connect(deafened=True) - assert vc == bot.get_bot_voice_state(guild.id) + await hook.delete() - audio = AudioVolume("tests/test_audio.mp3") - vc.play_no_wait(audio) - await asyncio.sleep(2) +@pytest.mark.asyncio +async def test_voice(bot: Client, guild: Guild) -> None: + test_channel = await guild.create_voice_channel(f"_test_voice-{bot.suffix}") + test_channel_two = await guild.create_voice_channel(f"_test_voice_two-{bot.suffix}") + try: + try: + import nacl # noqa + except ImportError: + # testing on a non-voice extra + return - assert len(vc.current_audio.buffer) != 0 - assert vc.player._sent_payloads != 0 + vc = await test_channel.connect(deafened=True) + assert vc == bot.get_bot_voice_state(guild.id) - await vc.move(test_channel_two) - await asyncio.sleep(2) + audio = AudioVolume("tests/test_audio.mp3") + vc.play_no_wait(audio) + await asyncio.sleep(2) - _before = vc.player._sent_payloads + assert len(vc.current_audio.buffer) != 0 + assert vc.player._sent_payloads != 0 - await test_channel_two.connect(deafened=True) + await vc.move(test_channel_two) + await asyncio.sleep(2) - await asyncio.sleep(2) + _before = vc.player._sent_payloads - assert vc.player._sent_payloads != _before + await test_channel_two.connect(deafened=True) - vc.volume = 1 - await asyncio.sleep(1) - vc.volume = 0.5 + await asyncio.sleep(2) - vc.pause() - await asyncio.sleep(0.1) - assert vc.player.paused - vc.resume() - await asyncio.sleep(0.1) - assert not vc.player.paused + assert vc.player._sent_payloads != _before - await vc.disconnect() - await vc._close_connection() - await vc.ws._closed.wait() + vc.volume = 1 + await asyncio.sleep(1) + vc.volume = 0.5 + + vc.pause() + await asyncio.sleep(0.1) + assert vc.player.paused + vc.resume() + await asyncio.sleep(0.1) + assert not vc.player.paused + + await vc.disconnect() + await vc._close_connection() + await vc.ws._closed.wait() + finally: + await test_channel.delete() + await test_channel_two.delete() @pytest.mark.asyncio @@ -544,3 +559,5 @@ async def test_checks(bot: Client, guild: Guild) -> None: assert await naff.has_any_role(has_role)(context) is True assert await naff.has_any_role(lacks_role)(context) is False assert await naff.has_any_role(has_role)(generate_dummy_context(dm=True)) is False + + await member.remove_role(has_role) diff --git a/tests/test_emoji.py b/tests/test_emoji.py index 1514518ba..b5dad4c95 100644 --- a/tests/test_emoji.py +++ b/tests/test_emoji.py @@ -1,3 +1,7 @@ +import string + +import emoji + from naff.models.discord.emoji import PartialEmoji, process_emoji, process_emoji_req_format __all__ = () @@ -54,3 +58,83 @@ def test_emoji_processing() -> None: assert str(from_str) == raw_sample assert PartialEmoji.from_str("").animated is True + + +def test_unicode_recognition() -> None: + for _e in emoji.EMOJI_DATA: + assert PartialEmoji.from_str(_e) is not None + + +def test_regional_indicators() -> None: + regional_indicators = [ + "đŸ‡Ļ", + "🇧", + "🇨", + "🇩", + "đŸ‡Ē", + "đŸ‡Ģ", + "đŸ‡Ŧ", + "🇭", + "🇮", + "đŸ‡¯", + "🇰", + "🇱", + "🇲", + "đŸ‡ŗ", + "🇴", + "đŸ‡ĩ", + "đŸ‡ļ", + "🇷", + "🇸", + "🇹", + "đŸ‡ē", + "đŸ‡ģ", + "đŸ‡ŧ", + "đŸ‡Ŋ", + "🇾", + "đŸ‡ŋ", + ] + for _e in regional_indicators: + assert PartialEmoji.from_str(_e) is not None + + +def test_numerical_emoji() -> None: + numerical_emoji = ["0ī¸âƒŖ", "1ī¸âƒŖ", "2ī¸âƒŖ", "3ī¸âƒŖ", "4ī¸âƒŖ", "5ī¸âƒŖ", "6ī¸âƒŖ", "7ī¸âƒŖ", "8ī¸âƒŖ", "9ī¸âƒŖ"] + for _e in numerical_emoji: + assert PartialEmoji.from_str(_e) is not None + + +def test_false_positives() -> None: + for _e in string.printable: + assert PartialEmoji.from_str(_e) is None + + unicode_general_punctuation = [ + "’", + "‘", + "“", + "”", + "â€Ļ", + "–", + "—", + "â€ĸ", + "â—Ļ", + "â€Ŗ", + "⁃", + "⁎", + "⁏", + "⁒", + "⁓", + "âē", + "âģ", + "âŧ", + "âŊ", + "⁞", + "âŋ", + "₊", + "₋", + "₌", + "₍", + "₎", + ] + for _e in unicode_general_punctuation: + assert PartialEmoji.from_str(_e) is None