Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b0df54e
refactor: Replace `_same_supertype` with a custom `@singledispatch`
dangotbanned Jan 20, 2026
c544b10
refactor: Just use a real class
dangotbanned Jan 20, 2026
02b7811
fix(typing): Satisfy `mypy`
dangotbanned Jan 20, 2026
56d3459
fix: Oops forgot the first element
dangotbanned Jan 20, 2026
ab71793
refactor(typing): Use slightly better names
dangotbanned Jan 20, 2026
a58d070
chore: Rename `default` -> `upper_bound`
dangotbanned Jan 20, 2026
c4c21f8
docs: Replace debugging doc
dangotbanned Jan 20, 2026
e4f7bf1
docs: More cleanup
dangotbanned Jan 20, 2026
01fbf85
refactor: Use `__slots__`, remove a field
dangotbanned Jan 20, 2026
d88d50b
docs: More, more cleanup
dangotbanned Jan 20, 2026
659f5c7
docs: lil bit of `.register` progress
dangotbanned Jan 20, 2026
2b01b2b
cov
dangotbanned Jan 20, 2026
5450652
test: Get full coverage for `@just_dispatch`
dangotbanned Jan 21, 2026
238f069
chore: Give it a simple repr
dangotbanned Jan 21, 2026
301f537
test: Oops, forgot that was an override
dangotbanned Jan 21, 2026
22c8029
Merge remote-tracking branch 'upstream/dtypes/supertyping' into dtype…
dangotbanned Jan 22, 2026
bc39c72
Merge branch 'dtypes/supertyping' into dtypes/supertyping-dispatch
dangotbanned Jan 24, 2026
3fc33c6
Merge remote-tracking branch 'upstream/dtypes/supertyping' into dtype…
dangotbanned Feb 3, 2026
14f81fc
revert: Keep only what is required
dangotbanned Feb 3, 2026
e4c6657
Merge remote-tracking branch 'upstream/dtypes/supertyping' into dtype…
dangotbanned Feb 3, 2026
802d939
refactor: Simplify `@just_dispatch` signature
dangotbanned Feb 3, 2026
308f389
fix(typing): Satisfy mypy
dangotbanned Feb 3, 2026
08f762a
test: Gotta get that coverage
dangotbanned Feb 16, 2026
ce3e730
Merge branch 'dtypes/supertyping' into dtypes/supertyping-dispatch
dangotbanned Feb 16, 2026
18d6273
docs: Restore a minimal version of `@just_dispatch` doc
dangotbanned Feb 16, 2026
ddd5d98
revert: Remove `Impl` alias
dangotbanned Feb 16, 2026
55c393e
refactor: Rename `Passthrough` -> `PassthroughFn`
dangotbanned Feb 20, 2026
2757f40
docs: Add note to use only on internal
dangotbanned Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions narwhals/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, overload

if TYPE_CHECKING:

class Deferred(Protocol):
def __call__(self, f: Callable[..., R], /) -> JustDispatch[R]: ...


__all__ = ["just_dispatch"]

R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
PassthroughFn = TypeVar("PassthroughFn", bound=Callable[..., Any])
"""Original function is passed-through unchanged."""


class JustDispatch(Generic[R_co]):
"""Single-dispatch wrapper produced by decorating a function with `@just_dispatch`."""

__slots__ = ("_registry", "_upper_bound")

def __init__(self, function: Callable[..., R_co], /, upper_bound: type[Any]) -> None:
self._upper_bound: type[Any] = upper_bound
self._registry: dict[type[Any], Callable[..., R_co]] = {upper_bound: function}

def dispatch(self, tp: type[Any], /) -> Callable[..., R_co]:
"""Get the implementation for a given type."""
if f := self._registry.get(tp):
return f
if issubclass(tp, self._upper_bound):
f = self._registry[tp] = self._registry[self._upper_bound]
return f
msg = f"{self._registry[self._upper_bound].__name__!r} does not support {tp.__name__!r}"
raise TypeError(msg)

def register(
self, tp: type[Any], *tps: type[Any]
) -> Callable[[PassthroughFn], PassthroughFn]:
"""Register types to dispatch via the decorated function."""

def decorate(f: PassthroughFn, /) -> PassthroughFn:
self._registry.update((tp_, f) for tp_ in (tp, *tps))
return f

return decorate

def __call__(self, arg: object, *args: Any, **kwds: Any) -> R_co:
"""Dispatch on the type of the first argument, passing through all arguments."""
return self.dispatch(arg.__class__)(arg, *args, **kwds)


@overload
def just_dispatch(function: Callable[..., R], /) -> JustDispatch[R]: ...
@overload
def just_dispatch(*, upper_bound: type[Any] = object) -> Deferred: ...
def just_dispatch(
function: Callable[..., R] | None = None, /, *, upper_bound: type[Any] = object
) -> JustDispatch[R] | Deferred:
"""Transform a function into a single-dispatch generic function.

An alternative take on [`@functools.singledispatch`]:
- without [MRO] fallback
- allows [*just*] the types registered and optionally an `upper_bound`

Arguments:
function: Function to decorate, where the body serves as the default implementation.
upper_bound: When there is no registered implementation for a specific type, it must
be a subclass of `upper_bound` to use the default implementation.

Tip:
`@just_dispatch` should only be used to decorate **internal functions** as we lose the docstring.

[`@functools.singledispatch`]: https://docs.python.org/3/library/functools.html#functools.singledispatch
[MRO]: https://docs.python.org/3/howto/mro.html#python-2-3-mro
[*just*]: https://github.com/jorenham/optype/blob/e7221ed1d3d02989d5d01873323bac9f88459f26/README.md#just
"""
if function is not None:
return JustDispatch(function, upper_bound)
return partial(JustDispatch[Any], upper_bound=upper_bound)
170 changes: 46 additions & 124 deletions narwhals/dtypes/_supertyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from collections import deque
from itertools import chain, product
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Final, Generic, cast
from typing import TYPE_CHECKING, Any

from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._dispatch import just_dispatch
from narwhals._typing_compat import TypeVar
from narwhals.dtypes._classes import (
Array,
Expand Down Expand Up @@ -56,7 +57,6 @@

from typing_extensions import TypeAlias, TypeIs

from narwhals.dtypes import IntegerType
from narwhals.dtypes._classes import _Bits
from narwhals.typing import TimeUnit

Expand All @@ -72,33 +72,11 @@ def lru_cache(*, maxsize: int) -> Callable[[_Fn], _Fn]: # noqa: ARG001
else:
from functools import cache, lru_cache

Incomplete: TypeAlias = Any
FrozenDTypes: TypeAlias = frozenset[type[DType]]
DTypeGroup: TypeAlias = frozenset[type[DType]]
Nested: TypeAlias = "Array | List | Struct"
Parametric: TypeAlias = (
"Datetime | DatetimeV1 | Decimal |Duration | DurationV1 | Enum | Nested"
)
SameTemporalT = TypeVar("SameTemporalT", Datetime, DatetimeV1, Duration, DurationV1)
"""Temporal data types, with a `time_unit` attribute."""

SameDatetimeT = TypeVar("SameDatetimeT", Datetime, DatetimeV1)
SameT = TypeVar(
"SameT",
Array,
List,
Struct,
Datetime,
DatetimeV1,
Decimal,
Duration,
DurationV1,
Enum,
)
DTypeT1 = TypeVar("DTypeT1", bound=DType)
DTypeT2 = TypeVar("DTypeT2", bound=DType, default=DTypeT1)
DTypeT1_co = TypeVar("DTypeT1_co", bound=DType, covariant=True)
DTypeT2_co = TypeVar("DTypeT2_co", bound=DType, covariant=True, default=DTypeT1_co)


def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes:
Expand Down Expand Up @@ -145,14 +123,6 @@ def _key_fn_time_unit(obj: Datetime | Duration, /) -> int:
return _TIME_UNIT_PER_SECOND[obj.time_unit]


@lru_cache(maxsize=_CACHE_SIZE * 2)
def downcast_time_unit(
left: SameTemporalT, right: SameTemporalT, /
) -> SameTemporalT | None:
"""Return the operand with the lowest precision time unit."""
return min(left, right, key=_key_fn_time_unit)


@lru_cache(maxsize=_CACHE_SIZE // 2)
def dtype_eq(left: DType, right: DType, /) -> bool:
return left == right
Expand Down Expand Up @@ -216,6 +186,18 @@ def has_inner(dtype: Any) -> TypeIs[Array | List]:
return isinstance(dtype, (Array, List))


@just_dispatch(upper_bound=DType)
def same_supertype(left: DType, right: DType, /) -> DType | None:
return left if dtype_eq(left, right) else None


@same_supertype.register(Duration, DurationV1)
@lru_cache(maxsize=_CACHE_SIZE * 2)
def downcast_time_unit(left: SameTemporalT, right: SameTemporalT, /) -> SameTemporalT:
"""Return the operand with the lowest precision time unit."""
return min(left, right, key=_key_fn_time_unit)


def _struct_fields_union(
left: Collection[Field], right: Collection[Field], /
) -> Struct | None:
Expand All @@ -233,7 +215,8 @@ def _struct_fields_union(
return Struct(longest_map)


def _struct_supertype(left: Struct, right: Struct, /) -> Struct | None:
@same_supertype.register(Struct)
def struct_supertype(left: Struct, right: Struct, /) -> Struct | None:
"""Get the supertype of two struct data types.

Adapted from [`super_type_structs`](https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/utils/supertype.rs#L588-L603)
Expand All @@ -252,84 +235,48 @@ def _struct_supertype(left: Struct, right: Struct, /) -> Struct | None:
return Struct(new_fields)


def _array_supertype(left: Array, right: Array, /) -> Array | None:
@same_supertype.register(Array)
def array_supertype(left: Array, right: Array, /) -> Array | None:
if (left.shape == right.shape) and (
inner := get_supertype(left.inner(), right.inner())
):
return Array(inner, left.size)
return None


def _list_supertype(left: List, right: List, /) -> List | None:
@same_supertype.register(List)
def list_supertype(left: List, right: List, /) -> List | None:
if inner := get_supertype(left.inner(), right.inner()):
return List(inner)
return None


def _datetime_supertype(
@same_supertype.register(Datetime, DatetimeV1)
def datetime_supertype(
left: SameDatetimeT, right: SameDatetimeT, /
) -> SameDatetimeT | None:
if left.time_zone != right.time_zone:
return None
return downcast_time_unit(left, right)


def _enum_supertype(left: Enum, right: Enum, /) -> Enum | None:
@same_supertype.register(Enum)
def enum_supertype(left: Enum, right: Enum, /) -> Enum | None:
return left if left.categories == right.categories else None


def _decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal:
@same_supertype.register(Decimal)
def decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal:
# https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L508-L511
precision = max(left.precision, right.precision)
scale = max(left.scale, right.scale)
return Decimal(precision=precision, scale=scale)


_SAME_DISPATCH: Final[Mapping[type[Parametric], Callable[..., Incomplete | None]]] = {
Array: _array_supertype,
List: _list_supertype,
Struct: _struct_supertype,
Datetime: _datetime_supertype,
DatetimeV1: _datetime_supertype,
Duration: downcast_time_unit,
DurationV1: downcast_time_unit,
Enum: _enum_supertype,
Decimal: _decimal_supertype,
}
"""Specialized supertyping rules for `(T, T)`.

*When operands share the same class*, all other data types can use `DType.__eq__` (see [#3393]).

[#3393]: https://github.com/narwhals-dev/narwhals/pull/3393
"""


def is_single_base_type(
st: _SupertypeCase[DTypeT1, DType],
) -> TypeIs[_SupertypeCase[DTypeT1]]:
return len(st.base_types) == 1


def is_parametric_case(
st: _SupertypeCase[SameT | DType],
) -> TypeIs[_SupertypeCase[SameT]]:
return st.base_left in _SAME_DISPATCH


def _get_same_supertype_fn(base: type[SameT]) -> Callable[[SameT, SameT], SameT | None]:
return cast("Callable[[SameT, SameT], SameT | None]", _SAME_DISPATCH[base])


def _same_supertype(st: _SupertypeCase[SameT | DType]) -> SameT | DType | None:
if is_parametric_case(st):
return _get_same_supertype_fn(st.base_left)(st.left, st.right)
return st.left if dtype_eq(st.left, st.right) else None


DEC128_MAX_PREC = 38
# Precomputing powers of 10 up to 10^38
POW10_LIST = tuple(10**i for i in range(DEC128_MAX_PREC + 1))
INT_MAX_MAP: Mapping[IntegerType, int] = {
INT_MAX_MAP: Mapping[Int, int] = {
UInt8(): (2**8) - 1,
UInt16(): (2**16) - 1,
UInt32(): (2**32) - 1,
Expand All @@ -349,7 +296,7 @@ def _integer_fits_in_decimal(value: int, precision: int, scale: int) -> bool:
)


def _decimal_integer_supertyping(decimal: Decimal, integer: IntegerType) -> DType | None:
def _decimal_integer_supertyping(decimal: Decimal, integer: Int) -> DType | None:
precision, scale = decimal.precision, decimal.scale

if integer in {UInt128(), Int128()}:
Expand All @@ -364,15 +311,15 @@ def _decimal_integer_supertyping(decimal: Decimal, integer: IntegerType) -> DTyp
return Decimal(precision, scale)


@lru_cache(maxsize=_CACHE_SIZE)
def _numeric_supertype(st: _SupertypeCase[DType]) -> DType | None:
def _numeric_supertype(
left: DType, right: DType, base_types: FrozenDTypes
) -> DType | None:
"""Get the supertype of two numeric data types that do not share the same class.

`_{primitive_numeric,integer}_supertyping` define most valid numeric supertypes.

We generate these on first use, with all subsequent calls returning the same mapping.
"""
base_types = st.base_types
if NUMERIC.issuperset(base_types):
if INTEGER.issuperset(base_types):
return _integer_supertyping()[base_types]()
Expand All @@ -382,9 +329,7 @@ def _numeric_supertype(st: _SupertypeCase[DType]) -> DType | None:
# Logic adapted from rust implementation
# https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L517-L530
decimal, integer = (
(st.left, st.right)
if isinstance(st.left, Decimal)
else (st.right, st.left)
(left, right) if isinstance(left, Decimal) else (right, left)
)
return _decimal_integer_supertyping(decimal=decimal, integer=integer) # type: ignore[arg-type]

Expand All @@ -404,47 +349,24 @@ def _mixed_nested_supertype(left: DType, right: DType, /) -> DType | None:
return None


def _mixed_supertype(st: _SupertypeCase[DType, DType]) -> DType | None:
def _mixed_supertype(
left: DType, right: DType, base_types: FrozenDTypes, /
) -> DType | None:
"""Get the supertype of two data types that do not share the same class."""
base_types = st.base_types
left, right = st.left, st.right
if Date in base_types and _has_intersection(base_types, DATETIME):
return left if isinstance(left, Datetime) else right
if String in base_types:
return (Binary if Binary in base_types else String)()
if has_nested(base_types):
return _mixed_nested_supertype(left, right)
return _numeric_supertype(st) if _has_intersection(NUMERIC, base_types) else None


class _SupertypeCase(Generic[DTypeT1_co, DTypeT2_co]):
"""WIP."""

__slots__ = ("base_types", "left", "right")

left: DTypeT1_co
right: DTypeT2_co
base_types: frozenset[type[DTypeT1_co | DTypeT2_co]]

def __init__(self, left: DTypeT1_co, right: DTypeT2_co) -> None:
self.left = left
self.right = right
self.base_types = frozenset((self.base_left, self.base_right))

@property
def base_left(self) -> type[DTypeT1_co]:
return self.left.base_type()

@property
def base_right(self) -> type[DTypeT2_co]:
return self.right.base_type()
return (
_numeric_supertype(left, right, base_types)
if _has_intersection(NUMERIC, base_types)
else None
)


# NOTE @dangotbanned: Tried **many** variants of this typing
# (to self) DO NOT TOUCH IT AGAIN
def get_supertype(
left: DTypeT1, right: DTypeT2 | DType
) -> DTypeT1 | DTypeT2 | DType | None:
def get_supertype(left: DType, right: DType) -> DType | None:
"""Given two data types, determine the data type that both types can reasonably safely be cast to.

Arguments:
Expand All @@ -454,9 +376,9 @@ def get_supertype(
Returns:
The common supertype that both types can be safely cast to, or None if no such type exists.
"""
st_case = _SupertypeCase(left, right)
if Unknown in st_case.base_types:
base_types = frozen_dtypes(left.base_type(), right.base_type())
if Unknown in base_types:
return Unknown()
if is_single_base_type(st_case):
return _same_supertype(st_case)
return _mixed_supertype(st_case)
if len(base_types) == 1:
return same_supertype(left, right)
return _mixed_supertype(left, right, base_types)
Loading
Loading