Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions interactions/api/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import AbstractEventLoop, get_event_loop
from asyncio import AbstractEventLoop, Future, get_event_loop
from logging import Logger
from typing import Callable, Coroutine, Dict, List, Optional

Expand All @@ -17,24 +17,22 @@ class Listener:
:ivar dict events: A list of events being dispatched.
"""

__slots__ = ("loop", "events")
__slots__ = ("loop", "events", "extra_events")

def __init__(self) -> None:
self.loop: AbstractEventLoop = get_event_loop()
self.events: Dict[str, List[Callable[..., Coroutine]]] = {}
self.extra_events: Dict[str, List[Future]] = {} # used in `Client.wait_for`

def dispatch(self, __name: str, *args, **kwargs) -> None:
def dispatch(self, name: str, /, *args, **kwargs) -> None:
r"""
Dispatches an event given out by the gateway.

:param __name: The name of the event to dispatch.
:type __name: str
:param *args: Multiple arguments of the coroutine.
:type *args: list[Any]
:param **kwargs: Keyword-only arguments of the coroutine.
:type **kwargs: dict
:param str name: The name of the event to dispatch.
:param list[Any] \*args: Multiple arguments of the coroutine.
:param dict \**kwargs: Keyword-only arguments of the coroutine.
"""
for event in self.events.get(__name, []):
for event in self.events.get(name, []):
converters: dict
if converters := getattr(event, "_converters", None):
_kwargs = kwargs.copy()
Expand All @@ -47,6 +45,23 @@ def dispatch(self, __name: str, *args, **kwargs) -> None:
self.loop.create_task(event(*args, **kwargs))
log.debug(f"DISPATCH: {event}")

# wait_for events
futs = self.extra_events.get(name, [])
if not futs:
return

log.debug(f"Resolving {len(futs)} futures")

for fut in futs:
if fut.done():
log.debug(
f"A future for the {name} event was already {'cancelled' if fut.cancelled() else 'resolved'}"
)
else:
fut.set_result(args)

self.extra_events[name] = []

def register(self, coro: Callable[..., Coroutine], name: Optional[str] = None) -> None:
"""
Registers a given coroutine as an event to be listened to.
Expand All @@ -66,3 +81,17 @@ def register(self, coro: Callable[..., Coroutine], name: Optional[str] = None) -

self.events[_name] = event
log.debug(f"REGISTER: {self.events[_name]}")

def add(self, name: str) -> Future:
"""
Returns a Future that will resolve whenever the supplied event is dispatched

:param str name: The event to listen for
:return: A future that will be resolved on the next event dispatch with the data given
:rtype: asyncio.Future
"""
fut = self.loop.create_future()
futures = self.extra_events.get(name, [])
futures.append(fut)
self.extra_events[name] = futures
return fut
166 changes: 162 additions & 4 deletions interactions/client/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,32 @@
import logging
import re
import sys
from asyncio import AbstractEventLoop, CancelledError, get_event_loop, iscoroutinefunction
from asyncio import AbstractEventLoop, CancelledError, get_event_loop, iscoroutinefunction, wait_for
from functools import wraps
from importlib import import_module
from importlib.util import resolve_name
from inspect import getmembers
from inspect import getmembers, isawaitable
from types import ModuleType
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, Union

from ..api import WebSocketClient as WSClient
from ..api.error import LibraryException
from ..api.http.client import HTTPClient
from ..api.models.flags import Intents, Permissions
from ..api.models.guild import Guild
from ..api.models.message import Message
from ..api.models.misc import Image, Snowflake
from ..api.models.presence import ClientPresence
from ..api.models.team import Application
from ..api.models.user import User
from ..base import get_logger
from ..utils.attrs_utils import convert_list
from ..utils.missing import MISSING
from .context import CommandContext, ComponentContext
from .decor import component as _component
from .enums import ApplicationCommandType, Locale, OptionType
from .models.command import ApplicationCommand, Choice, Command, Option
from .models.component import Button, Modal, SelectMenu
from .models.component import ActionRow, Button, Modal, SelectMenu

log: logging.Logger = get_logger("client")

Expand Down Expand Up @@ -1539,6 +1541,162 @@ async def _logout(self) -> None:
await self._websocket.close()
await self._http._req.close()

async def wait_for(
self,
name: str,
check: Optional[Callable[..., Union[bool, Awaitable[bool]]]] = None,
timeout: Optional[float] = None,
) -> Any:
"""
Waits for an event once, and returns the result.

Unlike event decorators, this is not persistent, and can be used to only proceed in a command once an event happens.

:param str name: The event to wait for
:param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned
:param float timeout: How long to wait for the event before raising an error
:return: The value of the dispatched event
:rtype: Any
"""
while True:
fut = self._websocket._dispatch.add(name=name)
try:
# asyncio's wait_for
res: tuple = await wait_for(fut, timeout=timeout)
except TimeoutError:
with contextlib.suppress(ValueError):
self._websocket._dispatch.extra_events[name].remove(fut)
raise

if not check:
break
checked = check(*res)
if isawaitable(checked):
checked = await checked
if checked:
break
else:
# The check failed, so try again next time
log.info(f"A check failed waiting for the {name} event")

if res:
return res[0] if len(res) == 1 else res

async def wait_for_component(
self,
components: Union[
Union[str, Button, SelectMenu],
List[Union[str, Button, SelectMenu]],
] = None,
messages: Union[Message, int, List[Union[Message, int]]] = None,
check: Optional[Callable[..., Union[bool, Awaitable[bool]]]] = None,
timeout: Optional[float] = None,
):
"""
Waits for a component to be interacted with, and returns the resulting context.

:param Union[str, interactions.Button, interactions.SelectMenu, List[Union[str, interactions.Button, interactions.SelectMenu]]] components: The component(s) to wait for
:param Union[interactions.Message, int, List[Union[interactions.Message, int]]] messages: The message(s) to check for
:param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned
:param float timeout: How long to wait for the event before raising an error
:return: The ComponentContext of the dispatched event
:rtype: interactions.ComponentContext
"""
custom_ids: List[str] = []
messages_ids: List[int] = []

if components:
if isinstance(components, list):
for component in components:
if isinstance(component, (Button, SelectMenu)):
custom_ids.append(component.custom_id)
elif isinstance(component, ActionRow):
custom_ids.extend([c.custom_id for c in component.components])
elif isinstance(component, list):
for c in component:
if isinstance(c, (Button, SelectMenu)):
custom_ids.append(c.custom_id)
elif isinstance(c, ActionRow):
custom_ids.extend([b.custom_id for b in c.components])
elif isinstance(c, str):
custom_ids.append(c)
elif isinstance(component, str):
custom_ids.append(component)
elif isinstance(components, (Button, SelectMenu)):
custom_ids.append(components.custom_id)
elif isinstance(components, ActionRow):
custom_ids.extend([c.custom_id for c in components.components]) # noqa
elif isinstance(components, str):
custom_ids.append(components)

if messages:
if isinstance(messages, Message):
messages_ids.append(int(messages.id))
elif isinstance(messages, list):
for message in messages:
if isinstance(message, Message):
messages_ids.append(int(message.id))
else:
messages_ids.append(int(message))
else: # account for plain ints, string, or Snowflakes
messages_ids.append(int(messages))

def _check(ctx: ComponentContext) -> bool:
if custom_ids and ctx.data.custom_id not in custom_ids:
return False
if messages_ids and int(ctx.message.id) not in messages_ids:
return False
return check(ctx) if check else True

return await self.wait_for("on_component", check=_check, timeout=timeout)

async def wait_for_modal(
self,
modals: Union[Modal, str, List[Union[Modal, str]]],
check: Optional[Callable[[CommandContext], bool]] = None,
timeout: Optional[float] = None,
):
"""
Waits for a modal to be interacted with, and returns the resulting context.

:param Union[Modal, str, List[Modal, str]] modals: The modal(s) to wait for
:param Callable check: A function or coroutine to call, which should return a truthy value if the data should be returned
:param Optional[float] timeout: How long to wait for the event before raising an error
:return: The context of the modal
:rtype: CommandContext
"""
ids: List[str] = []

if isinstance(modals, Modal):
ids = [str(modals.custom_id)]
elif isinstance(modals, str):
ids = [modals]
elif isinstance(modals, list):
for modal in modals:
if isinstance(modal, Modal):
ids.append(str(modal.custom_id))
elif isinstance(modal, str):
modals.append(modal)

if not all(isinstance(id, str) for id in ids):
raise TypeError("No modals were passed!")

def _check(ctx: CommandContext):
if ids and ctx.data.custom_id not in ids:
return False
return check(ctx) if check else True

ctx: CommandContext = await self.wait_for("on_modal", check=_check, timeout=timeout)

# Ed requested that it returns a result similar to the decorator
fields: List[str] = []
for actionrow in ctx.data.components: # discord is weird with this
if actionrow.components:
data = actionrow.components[0].value
fields.append(data)

return ctx, *fields


# TODO: Implement the rest of cog behaviour when possible.
class Extension:
Expand Down