Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions interactions/client/models/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,59 +816,83 @@ async def command_error(ctx, error):
message=f"Your command needs at least {'three parameters to return self, context, and the' if self.extension else 'two parameter to return context and'} error.",
)

self.error_callback = self.__wrap_coro(coro)
self.error_callback = self.__wrap_coro(coro, error_callback=True)
return coro

async def __call(
self,
coro: Callable[..., Awaitable],
ctx: "CommandContext",
*args,
*args, # empty for now since all parameters are dispatched as kwargs
_name: Optional[str] = None,
_res: Optional[Union[BaseResult, GroupResult]] = None,
**kwargs,
) -> Optional[Any]:
"""Handles calling the coroutine based on parameter count."""
param_len = len(signature(coro).parameters)
opt_len = self.num_options.get(_name, len(args) + len(kwargs))
params = signature(coro).parameters
param_len = len(params)
opt_len = self.num_options.get(_name, len(args) + len(kwargs)) # options of slash command
last = params[list(params)[-1]] # last parameter
has_args = any(param.kind == param.VAR_POSITIONAL for param in params.values()) # any *args
index_of_var_pos = next(
(i for i, param in enumerate(params.values()) if param.kind == param.VAR_POSITIONAL),
param_len,
) # index of *args
par_opts = list(params.keys())[
(num := 2 if self.extension else 1) : (
-1 if last.kind in (last.VAR_POSITIONAL, last.VAR_KEYWORD) else index_of_var_pos
)
] # parameters that are before *args and **kwargs
keyword_only_args = list(params.keys())[index_of_var_pos:] # parameters after *args

try:
_coro = coro if hasattr(coro, "_wrapped") else self.__wrap_coro(coro)

if param_len < (2 if self.extension else 1):
if last.kind == last.VAR_KEYWORD: # foo(ctx, ..., **kwargs)
return await _coro(ctx, *args, **kwargs)
if last.kind == last.VAR_POSITIONAL: # foo(ctx, ..., *args)
return await _coro(
ctx,
*(kwargs[opt] for opt in par_opts if opt in kwargs),
*args,
)
if has_args: # foo(ctx, ..., *args, ..., **kwargs) OR foo(ctx, *args, ...)
return await _coro(
ctx,
*(kwargs[opt] for opt in par_opts if opt in kwargs), # pos before *args
*args,
*(
kwargs[opt]
for opt in kwargs
if opt not in par_opts and opt not in keyword_only_args
), # additional args
**{
opt: kwargs[opt]
for opt in kwargs
if opt not in par_opts and opt in keyword_only_args
}, # kwargs after *args
)

if param_len < num:
inner_msg: str = f"{num} parameter{'s' if num > 1 else ''} to return" + (
" self and" if self.extension else ""
)
raise LibraryException(
code=11,
message=f"Your command needs at least {'two parameters to return self and' if self.extension else 'one parameter to return'} context.",
code=11, message=f"Your command needs at least {inner_msg} context."
)

if param_len == (2 if self.extension else 1):
if param_len == num:
return await _coro(ctx)

if _res:
if param_len - opt_len == (2 if self.extension else 1):
if param_len - opt_len == num:
return await _coro(ctx, *args, **kwargs)
elif param_len - opt_len == (3 if self.extension else 2):
elif param_len - opt_len == num + 1:
return await _coro(ctx, _res, *args, **kwargs)

return await _coro(ctx, *args, **kwargs)
except CancelledError:
pass
except Exception as e:
if self.error_callback:
num_params = len(signature(self.error_callback).parameters)

if num_params == (3 if self.extension else 2):
await self.error_callback(ctx, e)
elif num_params == (4 if self.extension else 3):
await self.error_callback(ctx, e, _res)
else:
await self.error_callback(ctx, e, _res, *args, **kwargs)
elif self.listener and "on_command_error" in self.listener.events:
self.listener.dispatch("on_command_error", ctx, e)
else:
raise e

return StopCommand

def __check_command(self, command_type: str) -> None:
"""Checks if subcommands, groups, or autocompletions are created on context menus."""
Expand All @@ -895,7 +919,9 @@ async def __no_group(self, *args, **kwargs) -> None:
"""This is the coroutine used when no group coroutine is provided."""
pass

def __wrap_coro(self, coro: Callable[..., Awaitable]) -> Callable[..., Awaitable]:
def __wrap_coro(
self, coro: Callable[..., Awaitable], /, *, error_callback: bool = False
) -> Callable[..., Awaitable]:
"""Wraps a coroutine to make sure the :class:`interactions.client.bot.Extension` is passed to the coroutine, if any."""

@wraps(coro)
Expand All @@ -907,11 +933,28 @@ async def wrapper(ctx: "CommandContext", *args, **kwargs):
except CancelledError:
pass
except Exception as e:
if error_callback:
raise e
if self.error_callback:
num_params = len(signature(self.error_callback).parameters)

if num_params == (3 if self.extension else 2):
params = signature(self.error_callback).parameters
num_params = len(params)
last = params[list(params)[-1]]
num = 2 if self.extension else 1

if num_params == num:
await self.error_callback(ctx)
elif num_params == num + 1:
await self.error_callback(ctx, e)
elif last.kind == last.VAR_KEYWORD:
if num_params == num + 2:
await self.error_callback(ctx, e, **kwargs)
elif num_params >= num + 3:
await self.error_callback(ctx, e, *args, **kwargs)
elif last.kind == last.VAR_POSITIONAL:
if num_params == num + 2:
await self.error_callback(ctx, e, *args)
elif num_params >= num + 3:
await self.error_callback(ctx, e, *args, **kwargs)
else:
await self.error_callback(ctx, e, *args, **kwargs)
elif self.listener and "on_command_error" in self.listener.events:
Expand Down
7 changes: 4 additions & 3 deletions interactions/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..api.models.message import Message
from ..api.models.misc import Snowflake
from ..client.bot import Client, Extension
from ..client.context import CommandContext
from ..client.context import CommandContext # noqa F401

__all__ = (
"autodefer",
Expand Down Expand Up @@ -67,7 +67,7 @@ async def command(ctx):
"""

def decorator(coro: Callable[..., Union[Awaitable, Coroutine]]) -> Callable[..., Awaitable]:
from ..client.context import ComponentContext
from ..client.context import CommandContext, ComponentContext # noqa F811

@wraps(coro)
async def deferring_func(
Expand All @@ -80,7 +80,8 @@ async def deferring_func(

if isinstance(args[0], (ComponentContext, CommandContext)):
self = ctx
ctx = list(args).pop(0)
args = list(args)
ctx = args.pop(0)

task: Task = loop.create_task(coro(self, ctx, *args, **kwargs))

Expand Down