diff --git a/docs/migration.rst b/docs/migration.rst index 3d4eb129f..5a12042e2 100644 --- a/docs/migration.rst +++ b/docs/migration.rst @@ -66,7 +66,7 @@ portal and add the intent to your current intents when connecting: 4.1.0 → 4.3.0 ~~~~~~~~~~~~~~~ -A new big change in this release is the implementation of the ``get` utility method. +A new big change in this release is the implementation of the ``get`` utility method. It allows you to no longer use ``**await bot._http...``. You can get more information by reading the `get-documentation`_. diff --git a/interactions/client/get.py b/interactions/client/get.py index 05438f69d..fb012df97 100644 --- a/interactions/client/get.py +++ b/interactions/client/get.py @@ -1,8 +1,8 @@ from asyncio import sleep from enum import Enum -from inspect import isawaitable, isfunction +from inspect import isawaitable from logging import getLogger -from typing import Coroutine, Iterable, List, Optional, Type, TypeVar, Union, get_args +from typing import Coroutine, List, Optional, Type, TypeVar, Union, get_args try: from typing import _GenericAlias @@ -43,7 +43,7 @@ class Force(str, Enum): HTTP = "http" -def get(*args, **kwargs): +def get(client: Client, obj: Type[_T], **kwargs) -> Optional[_T]: """ Helper method to get an object. @@ -56,17 +56,11 @@ def get(*args, **kwargs): * purely from cache * purely from HTTP * from HTTP if not found in cache else from cache - * Get an object from an iterable - * based of its name - * based of its ID - * based of a custom check - * based of any other attribute the object inside the iterable has The method has to be awaited when: * You don't force anything * You force HTTP The method doesn't have to be awaited when: - * You get from an iterable * You force cache .. note :: @@ -93,22 +87,6 @@ def get(*args, **kwargs): enforce HTTP. To prevent this bug from happening it is suggested using ``force="http"`` instead of the enum. - Getting from an iterable: - - .. code-block:: python - - # Getting an object from an iterable - check = lambda role: role.name == "ADMIN" and role.color == 0xff0000 - roles = [ - interactions.Role(name="NOT ADMIN", color=0xff0000), - interactions.Role(name="ADMIN", color=0xff0000), - ] - role = get(roles, check=check) - # role will be `interactions.Role(name="ADMIN", color=0xff0000)` - - You can specify *any* attribute to check that the object could have (although only ``check``, ``id`` and - ``name`` are type-hinted) and the method will check for a match. - Getting an object: Here you will see two examples on how to get a single objects and the variations of how the object can be @@ -198,81 +176,68 @@ def _check(): def _check(): return False - if len(args) == 2: + if not isinstance(obj, type) and not isinstance(obj, _GenericAlias): + client: Client + obj: Union[Type[_T], Type[List[_T]]] + raise LibraryException(message="The object must not be an instance of a class!", code=12) - if any(isinstance(_, Iterable) for _ in args): - raise LibraryException( - message="You can only use Iterables as single-argument!", code=12 - ) + kwargs = _resolve_kwargs(obj, **kwargs) + http_name = f"get_{obj.__name__.lower()}" + kwarg_name = f"{obj.__name__.lower()}_id" + if isinstance(obj, _GenericAlias) or _check(): + _obj: Type[_T] = get_args(obj)[0] + _objects: List[Union[_obj, Coroutine]] = [] + kwarg_name += "s" - client, obj = args - if not isinstance(obj, type) and not isinstance(obj, _GenericAlias): - client: Client - obj: Union[Type[_T], Type[List[_T]]] - raise LibraryException( - message="The object must not be an instance of a class!", code=12 - ) - - kwargs = _resolve_kwargs(obj, **kwargs) - http_name = f"get_{obj.__name__.lower()}" - kwarg_name = f"{obj.__name__.lower()}_id" - if isinstance(obj, _GenericAlias) or _check(): - _obj: Type[_T] = get_args(obj)[0] - _objects: List[Union[_obj, Coroutine]] = [] - kwarg_name += "s" + force_cache = kwargs.pop("force", None) == "cache" + force_http = kwargs.pop("force", None) == "http" - force_cache = kwargs.pop("force", None) == "cache" - force_http = kwargs.pop("force", None) == "http" + if not force_http: + _objects = _get_cache(_obj, client, kwarg_name, _list=True, **kwargs) - if not force_http: - _objects = _get_cache(_obj, client, kwarg_name, _list=True, **kwargs) + if force_cache: + return _objects - if force_cache: - return _objects + elif not force_http and None not in _objects: + return _return_cache(_objects) - elif not force_http and None not in _objects: - return _return_cache(_objects) + elif force_http: + _objects.clear() + _func = getattr(client._http, http_name) + for _id in kwargs.get(kwarg_name): + _kwargs = kwargs + _kwargs.pop(kwarg_name) + _kwargs[kwarg_name[:-1]] = _id + _objects.append(_func(**_kwargs)) + return _http_request(_obj, http=client._http, request=_objects) - elif force_http: - _objects.clear() - _func = getattr(client._http, http_name) - for _id in kwargs.get(kwarg_name): + else: + _func = getattr(client._http, http_name) + for _index, __obj in enumerate(_objects): + if __obj is None: + _id = kwargs.get(kwarg_name)[_index] _kwargs = kwargs _kwargs.pop(kwarg_name) _kwargs[kwarg_name[:-1]] = _id - _objects.append(_func(**_kwargs)) - return _http_request(_obj, http=client._http, request=_objects) - - else: - _func = getattr(client._http, http_name) - for _index, __obj in enumerate(_objects): - if __obj is None: - _id = kwargs.get(kwarg_name)[_index] - _kwargs = kwargs - _kwargs.pop(kwarg_name) - _kwargs[kwarg_name[:-1]] = _id - _request = _func(**_kwargs) - _objects[_index] = _request - return _http_request(_obj, http=client._http, request=_objects) - - _obj: Optional[_T] = None + _request = _func(**_kwargs) + _objects[_index] = _request + return _http_request(_obj, http=client._http, request=_objects) - force_cache = kwargs.pop("force", None) == "cache" - force_http = kwargs.pop("force", None) == "http" - if not force_http: - _obj = _get_cache(obj, client, kwarg_name, **kwargs) + _obj: Optional[_T] = None - if force_cache: - return _obj + force_cache = kwargs.pop("force", None) == "cache" + force_http = kwargs.pop("force", None) == "http" + if not force_http: + _obj = _get_cache(obj, client, kwarg_name, **kwargs) - elif not force_http and _obj: - return _return_cache(_obj) + if force_cache: + return _obj - else: - return _http_request(obj=obj, http=client._http, _name=http_name, **kwargs) + elif not force_http and _obj: + return _return_cache(_obj) - elif len(args) == 1: - return _search_iterable(*args, **kwargs) + else: + return _http_request(obj=obj, http=client._http, _name=http_name, **kwargs) async def _http_request( @@ -339,40 +304,6 @@ def _get_cache( return _obj -def _search_iterable(items: Iterable[_T], **kwargs) -> Optional[_T]: - if not isinstance(items, Iterable): - raise LibraryException(message="The specified items must be an iterable!", code=12) - - if not kwargs: - raise LibraryException( - message="You have to specify either a custom check or a keyword argument to check against!", - code=12, - ) - - if len(list(kwargs)) > 1: - raise LibraryException( - message="Only one keyword argument to check against is allowed!", code=12 - ) - - _arg = str(list(kwargs)[0]) - kwarg = kwargs.get(_arg) - kwarg_is_function: bool = isfunction(kwarg) - - __obj = next( - ( - item - for item in items - if ( - str(getattr(item, _arg, None)) == str(kwarg) - if not kwarg_is_function - else kwarg(item) - ) - ), - None, - ) - return __obj - - def _resolve_kwargs(obj, **kwargs): # This function is needed to get correct kwarg names if __id := kwargs.pop("parent_id", None): diff --git a/interactions/client/get.pyi b/interactions/client/get.pyi index 9064aaae6..8e6d1522c 100644 --- a/interactions/client/get.pyi +++ b/interactions/client/get.pyi @@ -1,15 +1,15 @@ -from typing import overload, Type, TypeVar, List, Iterable, Optional, Callable, Awaitable, Literal, Coroutine, Union +from enum import Enum +from typing import Awaitable, Coroutine, List, Literal, Optional, Type, TypeVar, Union, overload from interactions.client.bot import Client -from enum import Enum +from ..api.http.client import HTTPClient from ..api.models.channel import Channel from ..api.models.guild import Guild from ..api.models.member import Member -from ..api.models.message import Message, Emoji, Sticker +from ..api.models.message import Emoji, Message, Sticker +from ..api.models.role import Role from ..api.models.user import User from ..api.models.webhook import Webhook -from ..api.models.role import Role -from ..api.http.client import HTTPClient _SA = TypeVar("_SA", Channel, Guild, Webhook, User, Sticker) _MA = TypeVar("_MA", Member, Emoji, Role, Message) @@ -24,11 +24,6 @@ class Force(str, Enum): CACHE: str HTTP: str -# not API-object related -@overload -def get( - items: Iterable[_T], /, *, id: Optional[int] = None, name: Optional[str] = None, check: Callable[..., bool], **kwargs -) -> Optional[_T]: ... # API-object related @@ -101,7 +96,6 @@ def get( # Having a not-overloaded definition stops showing a warning/complaint from the IDE if wrong arguments are put in, # so we'll leave that out -def _search_iterable(item: Iterable[_T], **kwargs) -> Optional[_T]:... def _get_cache( _object: Type[_T], client: Client, kwarg_name: str, _list: bool = False, **kwargs ) -> Union[Optional[_T], List[Optional[_T]]]:... diff --git a/interactions/client/models/utils.py b/interactions/client/models/utils.py index 8bfaa06c3..3422865a2 100644 --- a/interactions/client/models/utils.py +++ b/interactions/client/models/utils.py @@ -1,6 +1,6 @@ from asyncio import Task, get_running_loop, sleep from functools import wraps -from typing import TYPE_CHECKING, Awaitable, Callable, List, Union +from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, List, Optional, TypeVar, Union from ...api.error import LibraryException from .component import ActionRow, Button, SelectMenu @@ -8,9 +8,10 @@ if TYPE_CHECKING: from ..context import CommandContext - __all__ = ("autodefer", "spread_to_rows") +_T = TypeVar("_T") + def autodefer( delay: Union[float, int] = 2, @@ -138,3 +139,32 @@ async def command(ctx): raise LibraryException(code=12, message="Number of rows exceeds 5.") return rows + + +def search_iterable( + iterable: Iterable[_T], check: Optional[Callable[[_T], bool]] = None, /, **kwargs +) -> List[_T]: + """ + Searches through an iterable for items that: + - Are True for the check, if one is given + - Have attributes that match the keyword arguments (e.x. passing `id=your_id` will only return objects with that id) + + :param iterable: The iterable to search through + :type iterable: Iterable + :param check: The check that items will be checked against + :type check: Callable[[Any], bool] + :param kwargs: Any attributes the items should have + :type kwargs: Any + :return: All items that match the check and keywords + :rtype: list + """ + if check: + iterable = filter(check, iterable) + + if kwargs: + iterable = filter( + lambda item: all(getattr(item, attr) == value for attr, value in kwargs.items()), + iterable, + ) + + return list(iterable)