diff --git a/interactions/api/dispatch.py b/interactions/api/dispatch.py index c35197d46..5060cda63 100644 --- a/interactions/api/dispatch.py +++ b/interactions/api/dispatch.py @@ -1,4 +1,4 @@ -from asyncio import AbstractEventLoop, get_event_loop +from asyncio import AbstractEventLoop, Future, get_event_loop from logging import Logger from typing import Callable, Coroutine, Dict, List, Optional @@ -17,24 +17,22 @@ class Listener: :ivar dict events: A list of events being dispatched. """ - __slots__ = ("loop", "events") + __slots__ = ("loop", "events", "extra_events") def __init__(self) -> None: self.loop: AbstractEventLoop = get_event_loop() self.events: Dict[str, List[Callable[..., Coroutine]]] = {} + self.extra_events: Dict[str, List[Future]] = {} # used in `Client.wait_for` - def dispatch(self, __name: str, *args, **kwargs) -> None: + def dispatch(self, name: str, /, *args, **kwargs) -> None: r""" Dispatches an event given out by the gateway. - :param __name: The name of the event to dispatch. - :type __name: str - :param *args: Multiple arguments of the coroutine. - :type *args: list[Any] - :param **kwargs: Keyword-only arguments of the coroutine. - :type **kwargs: dict + :param str name: The name of the event to dispatch. + :param list[Any] \*args: Multiple arguments of the coroutine. + :param dict \**kwargs: Keyword-only arguments of the coroutine. """ - for event in self.events.get(__name, []): + for event in self.events.get(name, []): converters: dict if converters := getattr(event, "_converters", None): _kwargs = kwargs.copy() @@ -47,6 +45,23 @@ def dispatch(self, __name: str, *args, **kwargs) -> None: self.loop.create_task(event(*args, **kwargs)) log.debug(f"DISPATCH: {event}") + # wait_for events + futs = self.extra_events.get(name, []) + if not futs: + return + + log.debug(f"Resolving {len(futs)} futures") + + for fut in futs: + if fut.done(): + log.debug( + f"A future for the {name} event was already {'cancelled' if fut.cancelled() else 'resolved'}" + ) + else: + fut.set_result(args) + + self.extra_events[name] = [] + def register(self, coro: Callable[..., Coroutine], name: Optional[str] = None) -> None: """ Registers a given coroutine as an event to be listened to. @@ -66,3 +81,17 @@ def register(self, coro: Callable[..., Coroutine], name: Optional[str] = None) - self.events[_name] = event log.debug(f"REGISTER: {self.events[_name]}") + + def add(self, name: str) -> Future: + """ + Returns a Future that will resolve whenever the supplied event is dispatched + + :param str name: The event to listen for + :return: A future that will be resolved on the next event dispatch with the data given + :rtype: asyncio.Future + """ + fut = self.loop.create_future() + futures = self.extra_events.get(name, []) + futures.append(fut) + self.extra_events[name] = futures + return fut diff --git a/interactions/client/bot.py b/interactions/client/bot.py index 480ba068e..84abf0f69 100644 --- a/interactions/client/bot.py +++ b/interactions/client/bot.py @@ -2,19 +2,20 @@ import logging import re import sys -from asyncio import AbstractEventLoop, CancelledError, get_event_loop, iscoroutinefunction +from asyncio import AbstractEventLoop, CancelledError, get_event_loop, iscoroutinefunction, wait_for from functools import wraps from importlib import import_module from importlib.util import resolve_name -from inspect import getmembers +from inspect import getmembers, isawaitable from types import ModuleType -from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, Union from ..api import WebSocketClient as WSClient from ..api.error import LibraryException from ..api.http.client import HTTPClient from ..api.models.flags import Intents, Permissions from ..api.models.guild import Guild +from ..api.models.message import Message from ..api.models.misc import Image, Snowflake from ..api.models.presence import ClientPresence from ..api.models.team import Application @@ -22,10 +23,11 @@ from ..base import get_logger from ..utils.attrs_utils import convert_list from ..utils.missing import MISSING +from .context import CommandContext, ComponentContext from .decor import component as _component from .enums import ApplicationCommandType, Locale, OptionType from .models.command import ApplicationCommand, Choice, Command, Option -from .models.component import Button, Modal, SelectMenu +from .models.component import ActionRow, Button, Modal, SelectMenu log: logging.Logger = get_logger("client") @@ -1539,6 +1541,173 @@ async def _logout(self) -> None: await self._websocket.close() await self._http._req.close() + async def wait_for( + self, + name: str, + check: Optional[Callable[..., Union[bool, Awaitable[bool]]]] = None, + timeout: Optional[float] = None, + ) -> Any: + """ + Waits for an event once, and returns the result. + + Unlike event decorators, this is not persistent, and can be used to only proceed in a command once an event happens. + + :param str name: The event to wait for + :param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned + :param float timeout: How long to wait for the event before raising an error + :return: The value of the dispatched event + :rtype: Any + """ + while True: + fut = self._websocket._dispatch.add(name=name) + try: + # asyncio's wait_for + res: tuple = await wait_for(fut, timeout=timeout) + except TimeoutError: + with contextlib.suppress(ValueError): + self._websocket._dispatch.extra_events[name].remove(fut) + raise + + if not check: + break + checked = check(*res) + if isawaitable(checked): + checked = await checked + if checked: + break + else: + # The check failed, so try again next time + log.info(f"A check failed waiting for the {name} event") + + if res: + return res[0] if len(res) == 1 else res + + async def wait_for_component( + self, + components: Union[ + Union[str, Button, SelectMenu], + List[Union[str, Button, SelectMenu]], + ] = None, + messages: Union[Message, int, List[Union[Message, int]]] = None, + check: Optional[Callable[..., Union[bool, Awaitable[bool]]]] = None, + timeout: Optional[float] = None, + ) -> ComponentContext: + """ + Waits for a component to be interacted with, and returns the resulting context. + + .. note:: + If you are waiting for a select menu, you can find the selected values in ``ctx.data.values`` + + :param Union[str, interactions.Button, interactions.SelectMenu, List[Union[str, interactions.Button, interactions.SelectMenu]]] components: The component(s) to wait for + :param Union[interactions.Message, int, List[Union[interactions.Message, int]]] messages: The message(s) to check for + :param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned + :param float timeout: How long to wait for the event before raising an error + :return: The ComponentContext of the dispatched event + :rtype: ComponentContext + """ + custom_ids: List[str] = [] + messages_ids: List[int] = [] + + if components: + if isinstance(components, list): + for component in components: + if isinstance(component, (Button, SelectMenu)): + custom_ids.append(component.custom_id) + elif isinstance(component, ActionRow): + custom_ids.extend([c.custom_id for c in component.components]) + elif isinstance(component, list): + for c in component: + if isinstance(c, (Button, SelectMenu)): + custom_ids.append(c.custom_id) + elif isinstance(c, ActionRow): + custom_ids.extend([b.custom_id for b in c.components]) + elif isinstance(c, str): + custom_ids.append(c) + elif isinstance(component, str): + custom_ids.append(component) + elif isinstance(components, (Button, SelectMenu)): + custom_ids.append(components.custom_id) + elif isinstance(components, ActionRow): + custom_ids.extend([c.custom_id for c in components.components]) # noqa + elif isinstance(components, str): + custom_ids.append(components) + + if messages: + if isinstance(messages, Message): + messages_ids.append(int(messages.id)) + elif isinstance(messages, list): + for message in messages: + if isinstance(message, Message): + messages_ids.append(int(message.id)) + else: + messages_ids.append(int(message)) + else: # account for plain ints, string, or Snowflakes + messages_ids.append(int(messages)) + + def _check(ctx: ComponentContext) -> bool: + if custom_ids and ctx.data.custom_id not in custom_ids: + return False + if messages_ids and int(ctx.message.id) not in messages_ids: + return False + return check(ctx) if check else True + + return await self.wait_for("on_component", check=_check, timeout=timeout) + + async def wait_for_modal( + self, + modals: Union[Modal, str, List[Union[Modal, str]]], + check: Optional[Callable[[CommandContext], bool]] = None, + timeout: Optional[float] = None, + ) -> Tuple[CommandContext, List[str]]: + """ + Waits for a modal to be interacted with, and returns the resulting context and submitted data. + + .. note:: + This function returns both the context of the modal and the data the user input. + The recommended way to use it is to do: + ``modal_ctx, fields = await bot.wait_for_modal(...)`` + + Alternatively, to get the fields immediately, you can do: + ``modal_ctx, (field1, field2, ...) = await bot.wait_for_modal(...)`` + + :param Union[Modal, str, List[Modal, str]] modals: The modal(s) to wait for + :param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned + :param Optional[float] timeout: How long to wait for the event before raising an error + :return: The context of the modal, followed by the data the user inputted + :rtype: tuple[CommandContext, list[str]] + """ + ids: List[str] = [] + + if isinstance(modals, Modal): + ids = [str(modals.custom_id)] + elif isinstance(modals, str): + ids = [modals] + elif isinstance(modals, list): + for modal in modals: + if isinstance(modal, Modal): + ids.append(str(modal.custom_id)) + elif isinstance(modal, str): + modals.append(modal) + + if not all(isinstance(id, str) for id in ids): + raise TypeError("No modals were passed!") + + def _check(ctx: CommandContext): + if ids and ctx.data.custom_id not in ids: + return False + return check(ctx) if check else True + + ctx: CommandContext = await self.wait_for("on_modal", check=_check, timeout=timeout) + + # Ed requested that it returns a result similar to the decorator + fields: List[str] = [] + for actionrow in ctx.data.components: # discord is weird with this + if actionrow.components: + data = actionrow.components[0].value + fields.append(data) + + return ctx, fields + # TODO: Implement the rest of cog behaviour when possible. class Extension: