From 4e8962d5837293bd9af834b2375cf2bb97a749ee Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Wed, 21 Feb 2024 18:40:19 -0500 Subject: [PATCH] fix: don't pass kwargs to modal callbacks that only have ctx Fixes #1517. Supercedes #1584. --- interactions/client/client.py | 11 +++++++++++ interactions/models/internal/application_commands.py | 7 ++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/interactions/client/client.py b/interactions/client/client.py index 2d4da248f..cbde28bee 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -1344,6 +1344,17 @@ def add_modal_callback(self, command: ModalCommand) -> None: Args: command: The command to add """ + # test for parameters that arent the ctx (or self) + if command.has_binding: + callback = functools.partial(command.callback, None, None) + else: + callback = functools.partial(command.callback, None) + + if not inspect.signature(callback).parameters: + # if there are none, notify the command to just pass the ctx and not kwargs + # TODO: just make modal callbacks not pass kwargs at all (breaking) + command._just_ctx = True + for listener in command.listeners: if isinstance(listener, re.Pattern): if listener in self._regex_component_callbacks.keys(): diff --git a/interactions/models/internal/application_commands.py b/interactions/models/internal/application_commands.py index b6cde4f0d..2fb5c4e99 100644 --- a/interactions/models/internal/application_commands.py +++ b/interactions/models/internal/application_commands.py @@ -844,7 +844,12 @@ class ComponentCommand(InteractionCommand): @attrs.define(eq=False, order=False, hash=False, kw_only=True) class ModalCommand(ComponentCommand): - ... + _just_ctx: bool = attrs.field(repr=False, default=False) + + async def call_callback(self, callback: Callable, context: "BaseContext") -> None: + if self._just_ctx: + return await self.call_with_binding(callback, context) + return await super().call_callback(callback, context) def _unpack_helper(iterable: typing.Iterable[str]) -> list[str]: