Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 20 additions & 28 deletions interactions/utils/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import Enum
from inspect import isawaitable
from logging import getLogger
from sys import version_info
from typing import TYPE_CHECKING, Coroutine, List, Optional, Type, TypeVar, Union, get_args

try:
Expand All @@ -10,7 +11,14 @@
except ImportError:
from typing import _BaseGenericAlias as _GenericAlias

from sys import version_info

if version_info < (3, 9):

class GenericAlias:
...

else:
from types import GenericAlias

from ..api.error import LibraryException
from ..api.models.emoji import Emoji
Expand Down Expand Up @@ -123,13 +131,13 @@ def get(client: "Client", obj: Type[_T], **kwargs) -> Optional[_T]:

# with http force
member = await get(
client, interactions.Member, parent_id=your_guild_id, object_id=your_member_id
client, interactions.Member, parent_id=your_guild_id, object_id=your_member_id, force="http",
)
# always has a value

# with cache force
member = await get(
client, interactions.Member, parent_id=your_guild_id, object_id=your_member_id
member = get(
client, interactions.Member, parent_id=your_guild_id, object_id=your_member_id, force="cache",
)
# because of cache only, this can be None

Expand Down Expand Up @@ -166,21 +174,7 @@ def get(client: "Client", obj: Type[_T], **kwargs) -> Optional[_T]:

"""

if version_info >= (3, 9):

def _check():
return (
obj == list[get_args(obj)[0]]
if isinstance(get_args(obj), tuple) and get_args(obj)
else False
)

else:

def _check():
return False

if not isinstance(obj, type) and not isinstance(obj, _GenericAlias):
if not isinstance(obj, type) and not isinstance(obj, (_GenericAlias, GenericAlias)):
raise LibraryException(message="The object must not be an instance of a class!", code=12)

client: "Client"
Expand All @@ -191,14 +185,13 @@ def _check():
force_cache = force_arg == "cache"
force_http = force_arg == "http"

if isinstance(obj, _GenericAlias) or _check():
if isinstance(obj, (_GenericAlias, GenericAlias)):
_obj: Type[_T] = get_args(obj)[0]
http_name = f"get_{_obj.__name__.lower()}"
kwarg_name = f"{_obj.__name__.lower()}_ids"
_objects: List[Union[_obj, Coroutine]] = []

if not force_http:
_objects = _get_cache(_obj, client, kwarg_name, _list=True, **kwargs)
_objects: List[Union[_obj, Coroutine]] = (
_get_cache(_obj, client, kwarg_name, _list=True, **kwargs) if force_http else []
) # some sourcery stuff i dunno

if force_cache:
return _objects
Expand Down Expand Up @@ -231,10 +224,9 @@ def _check():
http_name = f"get_{obj.__name__.lower()}"
kwarg_name = f"{obj.__name__.lower()}_id"

_obj: Optional[_T] = None

if not force_http:
_obj = _get_cache(obj, client, kwarg_name, **kwargs)
_obj: Optional[_T] = (
None if force_http else _get_cache(obj, client, kwarg_name, **kwargs)
) # more sourcery stuff

if force_cache:
return _obj
Expand Down