diff --git a/narwhals/_dispatch.py b/narwhals/_dispatch.py new file mode 100644 index 0000000000..b07702ce61 --- /dev/null +++ b/narwhals/_dispatch.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import partial +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, overload + +if TYPE_CHECKING: + + class Deferred(Protocol): + def __call__(self, f: Callable[..., R], /) -> JustDispatch[R]: ... + + +__all__ = ["just_dispatch"] + +R = TypeVar("R") +R_co = TypeVar("R_co", covariant=True) +PassthroughFn = TypeVar("PassthroughFn", bound=Callable[..., Any]) +"""Original function is passed-through unchanged.""" + + +class JustDispatch(Generic[R_co]): + """Single-dispatch wrapper produced by decorating a function with `@just_dispatch`.""" + + __slots__ = ("_registry", "_upper_bound") + + def __init__(self, function: Callable[..., R_co], /, upper_bound: type[Any]) -> None: + self._upper_bound: type[Any] = upper_bound + self._registry: dict[type[Any], Callable[..., R_co]] = {upper_bound: function} + + def dispatch(self, tp: type[Any], /) -> Callable[..., R_co]: + """Get the implementation for a given type.""" + if f := self._registry.get(tp): + return f + if issubclass(tp, self._upper_bound): + f = self._registry[tp] = self._registry[self._upper_bound] + return f + msg = f"{self._registry[self._upper_bound].__name__!r} does not support {tp.__name__!r}" + raise TypeError(msg) + + def register( + self, tp: type[Any], *tps: type[Any] + ) -> Callable[[PassthroughFn], PassthroughFn]: + """Register types to dispatch via the decorated function.""" + + def decorate(f: PassthroughFn, /) -> PassthroughFn: + self._registry.update((tp_, f) for tp_ in (tp, *tps)) + return f + + return decorate + + def __call__(self, arg: object, *args: Any, **kwds: Any) -> R_co: + """Dispatch on the type of the first argument, passing through all arguments.""" + return self.dispatch(arg.__class__)(arg, *args, **kwds) + + +@overload +def just_dispatch(function: Callable[..., R], /) -> JustDispatch[R]: ... +@overload +def just_dispatch(*, upper_bound: type[Any] = object) -> Deferred: ... +def just_dispatch( + function: Callable[..., R] | None = None, /, *, upper_bound: type[Any] = object +) -> JustDispatch[R] | Deferred: + """Transform a function into a single-dispatch generic function. + + An alternative take on [`@functools.singledispatch`]: + - without [MRO] fallback + - allows [*just*] the types registered and optionally an `upper_bound` + + Arguments: + function: Function to decorate, where the body serves as the default implementation. + upper_bound: When there is no registered implementation for a specific type, it must + be a subclass of `upper_bound` to use the default implementation. + + Tip: + `@just_dispatch` should only be used to decorate **internal functions** as we lose the docstring. + + [`@functools.singledispatch`]: https://docs.python.org/3/library/functools.html#functools.singledispatch + [MRO]: https://docs.python.org/3/howto/mro.html#python-2-3-mro + [*just*]: https://github.com/jorenham/optype/blob/e7221ed1d3d02989d5d01873323bac9f88459f26/README.md#just + """ + if function is not None: + return JustDispatch(function, upper_bound) + return partial(JustDispatch[Any], upper_bound=upper_bound) diff --git a/narwhals/dtypes/_supertyping.py b/narwhals/dtypes/_supertyping.py index df0df15835..d534f2bf93 100644 --- a/narwhals/dtypes/_supertyping.py +++ b/narwhals/dtypes/_supertyping.py @@ -13,9 +13,10 @@ from collections import deque from itertools import chain, product from operator import attrgetter -from typing import TYPE_CHECKING, Any, Final, Generic, cast +from typing import TYPE_CHECKING, Any from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND +from narwhals._dispatch import just_dispatch from narwhals._typing_compat import TypeVar from narwhals.dtypes._classes import ( Array, @@ -56,7 +57,6 @@ from typing_extensions import TypeAlias, TypeIs - from narwhals.dtypes import IntegerType from narwhals.dtypes._classes import _Bits from narwhals.typing import TimeUnit @@ -72,33 +72,11 @@ def lru_cache(*, maxsize: int) -> Callable[[_Fn], _Fn]: # noqa: ARG001 else: from functools import cache, lru_cache -Incomplete: TypeAlias = Any FrozenDTypes: TypeAlias = frozenset[type[DType]] DTypeGroup: TypeAlias = frozenset[type[DType]] -Nested: TypeAlias = "Array | List | Struct" -Parametric: TypeAlias = ( - "Datetime | DatetimeV1 | Decimal |Duration | DurationV1 | Enum | Nested" -) SameTemporalT = TypeVar("SameTemporalT", Datetime, DatetimeV1, Duration, DurationV1) """Temporal data types, with a `time_unit` attribute.""" - SameDatetimeT = TypeVar("SameDatetimeT", Datetime, DatetimeV1) -SameT = TypeVar( - "SameT", - Array, - List, - Struct, - Datetime, - DatetimeV1, - Decimal, - Duration, - DurationV1, - Enum, -) -DTypeT1 = TypeVar("DTypeT1", bound=DType) -DTypeT2 = TypeVar("DTypeT2", bound=DType, default=DTypeT1) -DTypeT1_co = TypeVar("DTypeT1_co", bound=DType, covariant=True) -DTypeT2_co = TypeVar("DTypeT2_co", bound=DType, covariant=True, default=DTypeT1_co) def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes: @@ -145,14 +123,6 @@ def _key_fn_time_unit(obj: Datetime | Duration, /) -> int: return _TIME_UNIT_PER_SECOND[obj.time_unit] -@lru_cache(maxsize=_CACHE_SIZE * 2) -def downcast_time_unit( - left: SameTemporalT, right: SameTemporalT, / -) -> SameTemporalT | None: - """Return the operand with the lowest precision time unit.""" - return min(left, right, key=_key_fn_time_unit) - - @lru_cache(maxsize=_CACHE_SIZE // 2) def dtype_eq(left: DType, right: DType, /) -> bool: return left == right @@ -216,6 +186,18 @@ def has_inner(dtype: Any) -> TypeIs[Array | List]: return isinstance(dtype, (Array, List)) +@just_dispatch(upper_bound=DType) +def same_supertype(left: DType, right: DType, /) -> DType | None: + return left if dtype_eq(left, right) else None + + +@same_supertype.register(Duration, DurationV1) +@lru_cache(maxsize=_CACHE_SIZE * 2) +def downcast_time_unit(left: SameTemporalT, right: SameTemporalT, /) -> SameTemporalT: + """Return the operand with the lowest precision time unit.""" + return min(left, right, key=_key_fn_time_unit) + + def _struct_fields_union( left: Collection[Field], right: Collection[Field], / ) -> Struct | None: @@ -233,7 +215,8 @@ def _struct_fields_union( return Struct(longest_map) -def _struct_supertype(left: Struct, right: Struct, /) -> Struct | None: +@same_supertype.register(Struct) +def struct_supertype(left: Struct, right: Struct, /) -> Struct | None: """Get the supertype of two struct data types. Adapted from [`super_type_structs`](https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/utils/supertype.rs#L588-L603) @@ -252,7 +235,8 @@ def _struct_supertype(left: Struct, right: Struct, /) -> Struct | None: return Struct(new_fields) -def _array_supertype(left: Array, right: Array, /) -> Array | None: +@same_supertype.register(Array) +def array_supertype(left: Array, right: Array, /) -> Array | None: if (left.shape == right.shape) and ( inner := get_supertype(left.inner(), right.inner()) ): @@ -260,13 +244,15 @@ def _array_supertype(left: Array, right: Array, /) -> Array | None: return None -def _list_supertype(left: List, right: List, /) -> List | None: +@same_supertype.register(List) +def list_supertype(left: List, right: List, /) -> List | None: if inner := get_supertype(left.inner(), right.inner()): return List(inner) return None -def _datetime_supertype( +@same_supertype.register(Datetime, DatetimeV1) +def datetime_supertype( left: SameDatetimeT, right: SameDatetimeT, / ) -> SameDatetimeT | None: if left.time_zone != right.time_zone: @@ -274,62 +260,23 @@ def _datetime_supertype( return downcast_time_unit(left, right) -def _enum_supertype(left: Enum, right: Enum, /) -> Enum | None: +@same_supertype.register(Enum) +def enum_supertype(left: Enum, right: Enum, /) -> Enum | None: return left if left.categories == right.categories else None -def _decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal: +@same_supertype.register(Decimal) +def decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal: # https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L508-L511 precision = max(left.precision, right.precision) scale = max(left.scale, right.scale) return Decimal(precision=precision, scale=scale) -_SAME_DISPATCH: Final[Mapping[type[Parametric], Callable[..., Incomplete | None]]] = { - Array: _array_supertype, - List: _list_supertype, - Struct: _struct_supertype, - Datetime: _datetime_supertype, - DatetimeV1: _datetime_supertype, - Duration: downcast_time_unit, - DurationV1: downcast_time_unit, - Enum: _enum_supertype, - Decimal: _decimal_supertype, -} -"""Specialized supertyping rules for `(T, T)`. - -*When operands share the same class*, all other data types can use `DType.__eq__` (see [#3393]). - -[#3393]: https://github.com/narwhals-dev/narwhals/pull/3393 -""" - - -def is_single_base_type( - st: _SupertypeCase[DTypeT1, DType], -) -> TypeIs[_SupertypeCase[DTypeT1]]: - return len(st.base_types) == 1 - - -def is_parametric_case( - st: _SupertypeCase[SameT | DType], -) -> TypeIs[_SupertypeCase[SameT]]: - return st.base_left in _SAME_DISPATCH - - -def _get_same_supertype_fn(base: type[SameT]) -> Callable[[SameT, SameT], SameT | None]: - return cast("Callable[[SameT, SameT], SameT | None]", _SAME_DISPATCH[base]) - - -def _same_supertype(st: _SupertypeCase[SameT | DType]) -> SameT | DType | None: - if is_parametric_case(st): - return _get_same_supertype_fn(st.base_left)(st.left, st.right) - return st.left if dtype_eq(st.left, st.right) else None - - DEC128_MAX_PREC = 38 # Precomputing powers of 10 up to 10^38 POW10_LIST = tuple(10**i for i in range(DEC128_MAX_PREC + 1)) -INT_MAX_MAP: Mapping[IntegerType, int] = { +INT_MAX_MAP: Mapping[Int, int] = { UInt8(): (2**8) - 1, UInt16(): (2**16) - 1, UInt32(): (2**32) - 1, @@ -349,7 +296,7 @@ def _integer_fits_in_decimal(value: int, precision: int, scale: int) -> bool: ) -def _decimal_integer_supertyping(decimal: Decimal, integer: IntegerType) -> DType | None: +def _decimal_integer_supertyping(decimal: Decimal, integer: Int) -> DType | None: precision, scale = decimal.precision, decimal.scale if integer in {UInt128(), Int128()}: @@ -364,15 +311,15 @@ def _decimal_integer_supertyping(decimal: Decimal, integer: IntegerType) -> DTyp return Decimal(precision, scale) -@lru_cache(maxsize=_CACHE_SIZE) -def _numeric_supertype(st: _SupertypeCase[DType]) -> DType | None: +def _numeric_supertype( + left: DType, right: DType, base_types: FrozenDTypes +) -> DType | None: """Get the supertype of two numeric data types that do not share the same class. `_{primitive_numeric,integer}_supertyping` define most valid numeric supertypes. We generate these on first use, with all subsequent calls returning the same mapping. """ - base_types = st.base_types if NUMERIC.issuperset(base_types): if INTEGER.issuperset(base_types): return _integer_supertyping()[base_types]() @@ -382,9 +329,7 @@ def _numeric_supertype(st: _SupertypeCase[DType]) -> DType | None: # Logic adapted from rust implementation # https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L517-L530 decimal, integer = ( - (st.left, st.right) - if isinstance(st.left, Decimal) - else (st.right, st.left) + (left, right) if isinstance(left, Decimal) else (right, left) ) return _decimal_integer_supertyping(decimal=decimal, integer=integer) # type: ignore[arg-type] @@ -404,47 +349,24 @@ def _mixed_nested_supertype(left: DType, right: DType, /) -> DType | None: return None -def _mixed_supertype(st: _SupertypeCase[DType, DType]) -> DType | None: +def _mixed_supertype( + left: DType, right: DType, base_types: FrozenDTypes, / +) -> DType | None: """Get the supertype of two data types that do not share the same class.""" - base_types = st.base_types - left, right = st.left, st.right if Date in base_types and _has_intersection(base_types, DATETIME): return left if isinstance(left, Datetime) else right if String in base_types: return (Binary if Binary in base_types else String)() if has_nested(base_types): return _mixed_nested_supertype(left, right) - return _numeric_supertype(st) if _has_intersection(NUMERIC, base_types) else None - - -class _SupertypeCase(Generic[DTypeT1_co, DTypeT2_co]): - """WIP.""" - - __slots__ = ("base_types", "left", "right") - - left: DTypeT1_co - right: DTypeT2_co - base_types: frozenset[type[DTypeT1_co | DTypeT2_co]] - - def __init__(self, left: DTypeT1_co, right: DTypeT2_co) -> None: - self.left = left - self.right = right - self.base_types = frozenset((self.base_left, self.base_right)) - - @property - def base_left(self) -> type[DTypeT1_co]: - return self.left.base_type() - - @property - def base_right(self) -> type[DTypeT2_co]: - return self.right.base_type() + return ( + _numeric_supertype(left, right, base_types) + if _has_intersection(NUMERIC, base_types) + else None + ) -# NOTE @dangotbanned: Tried **many** variants of this typing -# (to self) DO NOT TOUCH IT AGAIN -def get_supertype( - left: DTypeT1, right: DTypeT2 | DType -) -> DTypeT1 | DTypeT2 | DType | None: +def get_supertype(left: DType, right: DType) -> DType | None: """Given two data types, determine the data type that both types can reasonably safely be cast to. Arguments: @@ -454,9 +376,9 @@ def get_supertype( Returns: The common supertype that both types can be safely cast to, or None if no such type exists. """ - st_case = _SupertypeCase(left, right) - if Unknown in st_case.base_types: + base_types = frozen_dtypes(left.base_type(), right.base_type()) + if Unknown in base_types: return Unknown() - if is_single_base_type(st_case): - return _same_supertype(st_case) - return _mixed_supertype(st_case) + if len(base_types) == 1: + return same_supertype(left, right) + return _mixed_supertype(left, right, base_types) diff --git a/tests/dispatch_test.py b/tests/dispatch_test.py new file mode 100644 index 0000000000..2745ba6bf4 --- /dev/null +++ b/tests/dispatch_test.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import decimal +import string + +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals._dispatch import JustDispatch, just_dispatch +from narwhals.dtypes import DType, SignedIntegerType + + +def type_name(obj: object) -> str: + return type(obj).__name__ + + +@pytest.fixture +def dispatch_no_bound() -> JustDispatch[str]: + @just_dispatch + def dtype_repr_code(dtype: DType) -> str: + return type_name(dtype).lower() + + return dtype_repr_code + + +@pytest.fixture +def dispatch_upper_bound() -> JustDispatch[str]: + @just_dispatch(upper_bound=DType) + def dtype_repr_code(dtype: DType) -> str: + return type_name(dtype).lower() + + return dtype_repr_code + + +@pytest.fixture +def stdlib_decimal() -> decimal.Decimal: + return decimal.Decimal("1.0") + + +def test_just_dispatch( + dispatch_no_bound: JustDispatch[str], stdlib_decimal: decimal.Decimal +) -> None: + i64 = nw.Int64() + assert dispatch_no_bound(i64) == "int64" + + @dispatch_no_bound.register(*SignedIntegerType.__subclasses__()) + def repr_int(dtype: SignedIntegerType) -> str: + return f"i{type_name(dtype).strip(string.ascii_letters)}" + + assert dispatch_no_bound(i64) == "i64" + assert repr_int(i64) == "i64" + assert dispatch_no_bound(nw.UInt8()) == "uint8" + assert dispatch_no_bound(stdlib_decimal) == "decimal" + + +def test_just_dispatch_upper_bound( + dispatch_upper_bound: JustDispatch[str], stdlib_decimal: decimal.Decimal +) -> None: + i64 = nw.Int64() + assert dispatch_upper_bound(i64) == "int64" + + @dispatch_upper_bound.register(*SignedIntegerType.__subclasses__()) + def repr_int(dtype: SignedIntegerType) -> str: + return f"i{type_name(dtype).strip(string.ascii_letters)}" + + assert dispatch_upper_bound(i64) == "i64" + assert repr_int(i64) == "i64" + assert dispatch_upper_bound(nw.UInt8()) == "uint8" + assert dispatch_upper_bound(nw.Decimal()) == "decimal" + + with pytest.raises(TypeError, match=r"'dtype_repr_code' does not support 'Decimal'"): + dispatch_upper_bound(stdlib_decimal) + + dispatch_upper_bound.register(type(stdlib_decimal))(lambda _: "need to be explicit") + assert dispatch_upper_bound(stdlib_decimal) == "need to be explicit" + + +def test_just_dispatch_no_auto_subclass(dispatch_no_bound: JustDispatch[str]) -> None: + NOT_REGISTERED = "datetime" # noqa: N806 + TZ = "Africa/Accra" # noqa: N806 + + assert dispatch_no_bound(nw.Datetime("ms")) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime("us")) == NOT_REGISTERED + + @dispatch_no_bound.register(nw.Datetime) + def repr_datetime(dtype: nw.Datetime) -> str: + if dtype.time_zone is None: + args: str = dtype.time_unit + else: + args = f"{dtype.time_unit}, {dtype.time_zone}" + return f"datetime[{args}]" + + assert dispatch_no_bound(nw.Datetime()) == "datetime[us]" + assert dispatch_no_bound(nw.Datetime("s")) == "datetime[s]" + assert dispatch_no_bound(nw.Datetime(time_zone=TZ)) == f"datetime[us, {TZ}]" + + assert dispatch_no_bound(nw_v1.Datetime()) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime("s")) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime(time_zone=TZ)) == NOT_REGISTERED + + dispatch_no_bound.register(nw_v1.Datetime)(repr_datetime) + + assert dispatch_no_bound(nw_v1.Datetime()) == "datetime[us]" + assert dispatch_no_bound(nw_v1.Datetime("s")) == "datetime[s]" + assert dispatch_no_bound(nw_v1.Datetime(time_zone=TZ)) == f"datetime[us, {TZ}]"