diff --git a/.gitignore b/.gitignore index 31bcaee4b1..709e1fda52 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ site/ todo.md docs/this.md docs/api-completeness/*.md +docs/concepts/promotion-rules.md !docs/api-completeness/index.md # Lock files diff --git a/mkdocs.yml b/mkdocs.yml index aec01a801e..12e9c235dd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -19,6 +19,7 @@ nav: - concepts/column_names.md - concepts/boolean.md - concepts/null_handling.md + - concepts/promotion-rules.md - Overhead: overhead.md - Perfect backwards compatibility policy: backcompat.md - Extensions and Plugins: extending.md @@ -81,8 +82,9 @@ theme: favicon: assets/logo.svg logo: assets/logo.svg features: - - content.code.copy - content.code.annotate + - content.code.copy + - content.footnote.tooltips - navigation.footer - navigation.indexes - navigation.top @@ -126,13 +128,14 @@ plugins: hooks: - utils/generate_backend_completeness.py +- utils/generate_supertyping.py - utils/generate_zen_content.py - markdown_extensions: - admonition - md_in_html - pymdownx.details +- footnotes - pymdownx.tabbed: alternate_style: true - pymdownx.superfences: diff --git a/narwhals/_dispatch.py b/narwhals/_dispatch.py new file mode 100644 index 0000000000..b07702ce61 --- /dev/null +++ b/narwhals/_dispatch.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import partial +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, overload + +if TYPE_CHECKING: + + class Deferred(Protocol): + def __call__(self, f: Callable[..., R], /) -> JustDispatch[R]: ... 
+ + +__all__ = ["just_dispatch"] + +R = TypeVar("R") +R_co = TypeVar("R_co", covariant=True) +PassthroughFn = TypeVar("PassthroughFn", bound=Callable[..., Any]) +"""Original function is passed-through unchanged.""" + + +class JustDispatch(Generic[R_co]): + """Single-dispatch wrapper produced by decorating a function with `@just_dispatch`.""" + + __slots__ = ("_registry", "_upper_bound") + + def __init__(self, function: Callable[..., R_co], /, upper_bound: type[Any]) -> None: + self._upper_bound: type[Any] = upper_bound + self._registry: dict[type[Any], Callable[..., R_co]] = {upper_bound: function} + + def dispatch(self, tp: type[Any], /) -> Callable[..., R_co]: + """Get the implementation for a given type.""" + if f := self._registry.get(tp): + return f + if issubclass(tp, self._upper_bound): + f = self._registry[tp] = self._registry[self._upper_bound] + return f + msg = f"{self._registry[self._upper_bound].__name__!r} does not support {tp.__name__!r}" + raise TypeError(msg) + + def register( + self, tp: type[Any], *tps: type[Any] + ) -> Callable[[PassthroughFn], PassthroughFn]: + """Register types to dispatch via the decorated function.""" + + def decorate(f: PassthroughFn, /) -> PassthroughFn: + self._registry.update((tp_, f) for tp_ in (tp, *tps)) + return f + + return decorate + + def __call__(self, arg: object, *args: Any, **kwds: Any) -> R_co: + """Dispatch on the type of the first argument, passing through all arguments.""" + return self.dispatch(arg.__class__)(arg, *args, **kwds) + + +@overload +def just_dispatch(function: Callable[..., R], /) -> JustDispatch[R]: ... +@overload +def just_dispatch(*, upper_bound: type[Any] = object) -> Deferred: ... +def just_dispatch( + function: Callable[..., R] | None = None, /, *, upper_bound: type[Any] = object +) -> JustDispatch[R] | Deferred: + """Transform a function into a single-dispatch generic function. 
+ + An alternative take on [`@functools.singledispatch`]: + - without [MRO] fallback + - allows [*just*] the types registered and optionally an `upper_bound` + + Arguments: + function: Function to decorate, where the body serves as the default implementation. + upper_bound: When there is no registered implementation for a specific type, it must + be a subclass of `upper_bound` to use the default implementation. + + Tip: + `@just_dispatch` should only be used to decorate **internal functions** as we lose the docstring. + + [`@functools.singledispatch`]: https://docs.python.org/3/library/functools.html#functools.singledispatch + [MRO]: https://docs.python.org/3/howto/mro.html#python-2-3-mro + [*just*]: https://github.com/jorenham/optype/blob/e7221ed1d3d02989d5d01873323bac9f88459f26/README.md#just + """ + if function is not None: + return JustDispatch(function, upper_bound) + return partial(JustDispatch[Any], upper_bound=upper_bound) diff --git a/narwhals/dtypes/__init__.py b/narwhals/dtypes/__init__.py new file mode 100644 index 0000000000..6f42cc5294 --- /dev/null +++ b/narwhals/dtypes/__init__.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from narwhals.dtypes._classes import ( + Array, + Binary, + Boolean, + Categorical, + Date, + Datetime, + Decimal, + DType, + Duration, + Enum, + Field, + Float32, + Float64, + FloatType, + Int8, + Int16, + Int32, + Int64, + Int128, + IntegerType, + List, + NestedType, + NumericType, + Object, + SignedIntegerType, + String, + Struct, + TemporalType, + Time, + UInt8, + UInt16, + UInt32, + UInt64, + UInt128, + Unknown, + UnsignedIntegerType, +) + +__all__ = [ + "Array", + "Binary", + "Boolean", + "Categorical", + "DType", + "Date", + "Datetime", + "Decimal", + "Duration", + "Enum", + "Field", + "Float32", + "Float64", + "FloatType", + "Int8", + "Int16", + "Int32", + "Int64", + "Int128", + "IntegerType", + "List", + "NestedType", + "NumericType", + "Object", + "SignedIntegerType", + "String", + "Struct", + 
"TemporalType", + "Time", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "UInt128", + "Unknown", + "UnsignedIntegerType", +] diff --git a/narwhals/dtypes.py b/narwhals/dtypes/_classes.py similarity index 94% rename from narwhals/dtypes.py rename to narwhals/dtypes/_classes.py index 587b6c2758..6df97df20d 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes/_classes.py @@ -2,61 +2,24 @@ import enum from collections import OrderedDict -from collections.abc import Iterable, Mapping +from collections.abc import Mapping from datetime import timezone from itertools import starmap -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar -from narwhals._utils import ( - _DeferredIterable, - isinstance_or_issubclass, - qualified_type_name, -) +from narwhals._utils import _DeferredIterable, isinstance_or_issubclass from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Sequence - from typing import Any + from collections.abc import Iterable, Iterator, Sequence + from typing import Any, Literal import _typeshed - from typing_extensions import Self, TypeIs + from typing_extensions import Self, TypeAlias from narwhals.typing import IntoDType, TimeUnit - -def _validate_dtype(dtype: DType | type[DType]) -> None: - if not isinstance_or_issubclass(dtype, DType): - msg = ( - f"Expected Narwhals dtype, got: {type(dtype)}.\n\n" - "Hint: if you were trying to cast to a type, use e.g. nw.Int64 instead of 'int64'." 
- ) - raise TypeError(msg) - - -def _is_into_dtype(obj: Any) -> TypeIs[IntoDType]: - return isinstance(obj, DType) or ( - isinstance(obj, DTypeClass) and not issubclass(obj, NestedType) - ) - - -def _is_nested_type(obj: Any) -> TypeIs[type[NestedType]]: - return isinstance(obj, DTypeClass) and issubclass(obj, NestedType) - - -def _validate_into_dtype(dtype: Any) -> None: - if not _is_into_dtype(dtype): - if _is_nested_type(dtype): - name = f"nw.{dtype.__name__}" - msg = ( - f"{name!r} is not valid in this context.\n\n" - f"Hint: instead of:\n\n" - f" {name}\n\n" - "use:\n\n" - f" {name}(...)" - ) - else: - msg = f"Expected Narwhals dtype, got: {qualified_type_name(dtype)!r}." - raise TypeError(msg) + _Bits: TypeAlias = Literal[8, 16, 32, 64, 128] class DTypeClass(type): @@ -176,6 +139,9 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] def __hash__(self) -> int: return hash(self.__class__) + def __call__(self) -> Self: + return self + class NumericType(DType): """Base class for numeric data types.""" @@ -184,12 +150,19 @@ class NumericType(DType): class IntegerType(NumericType): """Base class for integer data types.""" + # NOTE: Likely going to need an `Integer` metaclass, to be able to use `Final` or a class property + _bits: ClassVar[_Bits] + + def __init_subclass__(cls, *args: Any, bits: _Bits, **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + cls._bits = bits + -class SignedIntegerType(IntegerType): +class SignedIntegerType(IntegerType, bits=128): """Base class for signed integer data types.""" -class UnsignedIntegerType(IntegerType): +class UnsignedIntegerType(IntegerType, bits=128): """Base class for unsigned integer data types.""" @@ -261,7 +234,7 @@ def __repr__(self) -> str: # pragma: no cover return f"{class_name}(precision={self.precision!r}, scale={self.scale!r})" -class Int128(SignedIntegerType): +class Int128(SignedIntegerType, bits=128): """128-bit signed integer type. 
Examples: @@ -281,7 +254,7 @@ class Int128(SignedIntegerType): """ -class Int64(SignedIntegerType): +class Int64(SignedIntegerType, bits=64): """64-bit signed integer type. Examples: @@ -294,7 +267,7 @@ class Int64(SignedIntegerType): """ -class Int32(SignedIntegerType): +class Int32(SignedIntegerType, bits=32): """32-bit signed integer type. Examples: @@ -307,7 +280,7 @@ class Int32(SignedIntegerType): """ -class Int16(SignedIntegerType): +class Int16(SignedIntegerType, bits=16): """16-bit signed integer type. Examples: @@ -320,7 +293,7 @@ class Int16(SignedIntegerType): """ -class Int8(SignedIntegerType): +class Int8(SignedIntegerType, bits=8): """8-bit signed integer type. Examples: @@ -333,7 +306,7 @@ class Int8(SignedIntegerType): """ -class UInt128(UnsignedIntegerType): +class UInt128(UnsignedIntegerType, bits=128): """128-bit unsigned integer type. Examples: @@ -347,7 +320,7 @@ class UInt128(UnsignedIntegerType): """ -class UInt64(UnsignedIntegerType): +class UInt64(UnsignedIntegerType, bits=64): """64-bit unsigned integer type. Examples: @@ -360,7 +333,7 @@ class UInt64(UnsignedIntegerType): """ -class UInt32(UnsignedIntegerType): +class UInt32(UnsignedIntegerType, bits=32): """32-bit unsigned integer type. Examples: @@ -373,7 +346,7 @@ class UInt32(UnsignedIntegerType): """ -class UInt16(UnsignedIntegerType): +class UInt16(UnsignedIntegerType, bits=16): """16-bit unsigned integer type. Examples: @@ -386,7 +359,7 @@ class UInt16(UnsignedIntegerType): """ -class UInt8(UnsignedIntegerType): +class UInt8(UnsignedIntegerType, bits=8): """8-bit unsigned integer type. 
Examples: diff --git a/narwhals/stable/v1/_dtypes.py b/narwhals/dtypes/_classes_v1.py similarity index 63% rename from narwhals/stable/v1/_dtypes.py rename to narwhals/dtypes/_classes_v1.py index 5b4ea54958..960c34d68e 100644 --- a/narwhals/stable/v1/_dtypes.py +++ b/narwhals/dtypes/_classes_v1.py @@ -3,43 +3,12 @@ from typing import TYPE_CHECKING from narwhals._utils import inherit_doc -from narwhals.dtypes import ( - Array, - Binary, - Boolean, - Categorical, - Date, +from narwhals.dtypes._classes import ( Datetime as NwDatetime, - Decimal, DType, DTypeClass, Duration as NwDuration, Enum as NwEnum, - Field, - Float32, - Float64, - FloatType, - Int8, - Int16, - Int32, - Int64, - Int128, - IntegerType, - List, - NestedType, - NumericType, - Object, - SignedIntegerType, - String, - Struct, - Time, - UInt8, - UInt16, - UInt32, - UInt64, - UInt128, - Unknown, - UnsignedIntegerType, ) if TYPE_CHECKING: @@ -47,6 +16,8 @@ from narwhals.typing import TimeUnit +__all__ = ["Datetime", "Duration", "Enum"] + class Datetime(NwDatetime): @inherit_doc(NwDatetime) @@ -95,42 +66,3 @@ def __hash__(self) -> int: # pragma: no cover def __repr__(self) -> str: # pragma: no cover return super(NwEnum, self).__repr__() - - -__all__ = [ - "Array", - "Binary", - "Boolean", - "Categorical", - "DType", - "Date", - "Datetime", - "Decimal", - "Duration", - "Enum", - "Field", - "Float32", - "Float64", - "FloatType", - "Int8", - "Int16", - "Int32", - "Int64", - "Int128", - "IntegerType", - "List", - "NestedType", - "NumericType", - "Object", - "SignedIntegerType", - "String", - "Struct", - "Time", - "UInt8", - "UInt16", - "UInt32", - "UInt64", - "UInt128", - "Unknown", - "UnsignedIntegerType", -] diff --git a/narwhals/dtypes/_supertyping.py b/narwhals/dtypes/_supertyping.py new file mode 100644 index 0000000000..d534f2bf93 --- /dev/null +++ b/narwhals/dtypes/_supertyping.py @@ -0,0 +1,384 @@ +"""Rules for safe type promotion. + +Follows a subset of `polars`' [`get_supertype_with_options`]. 
+ +See [Data type promotion rules] for an in-depth explanation. + +[`get_supertype_with_options`]: https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L142-L543 +[Data type promotion rules]: https://narwhals-dev.github.io/narwhals/concepts/promotion-rules/ +""" + +from __future__ import annotations + +from collections import deque +from itertools import chain, product +from operator import attrgetter +from typing import TYPE_CHECKING, Any + +from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND +from narwhals._dispatch import just_dispatch +from narwhals._typing_compat import TypeVar +from narwhals.dtypes._classes import ( + Array, + Binary, + Boolean, + Date, + Datetime, + Decimal, + DType, + Duration, + Enum, + Field, + Float32, + Float64, + FloatType as Float, + Int8, + Int16, + Int32, + Int64, + Int128, + IntegerType as Int, + List, + SignedIntegerType, + String, + Struct, + UInt8, + UInt16, + UInt32, + UInt64, + UInt128, + Unknown, + UnsignedIntegerType, +) +from narwhals.dtypes._classes_v1 import Datetime as DatetimeV1, Duration as DurationV1 + +if TYPE_CHECKING: + from collections.abc import Callable, Collection, Mapping + + from typing_extensions import TypeAlias, TypeIs + + from narwhals.dtypes._classes import _Bits + from narwhals.typing import TimeUnit + + _Fn = TypeVar("_Fn", bound=Callable[..., Any]) + + # NOTE: Hack to make `functools.cache` *not* negatively impact typing + def cache(fn: _Fn, /) -> _Fn: + return fn + + # NOTE: Double hack + pretends `maxsize` is keyword-only and has no default + def lru_cache(*, maxsize: int) -> Callable[[_Fn], _Fn]: # noqa: ARG001 + return cache +else: + from functools import cache, lru_cache + +FrozenDTypes: TypeAlias = frozenset[type[DType]] +DTypeGroup: TypeAlias = frozenset[type[DType]] +SameTemporalT = TypeVar("SameTemporalT", Datetime, DatetimeV1, Duration, DurationV1) +"""Temporal data types, with a `time_unit` 
attribute.""" +SameDatetimeT = TypeVar("SameDatetimeT", Datetime, DatetimeV1) + + +def frozen_dtypes(*dtypes: type[DType]) -> FrozenDTypes: + """Alternative `frozenset` constructor. + + Gets `mypy` to stop inferring a more precise type (that later becomes incompatible). + """ + return frozenset(dtypes) + + +_CACHE_SIZE = 32 +"""Arbitrary size (currently). + +- 27 concrete `DType` classes +- 3 (V1) subclasses +- Pairwise comparisons, but order (of classes) is not important +""" + + +SIGNED_INTEGER: DTypeGroup = frozenset((Int8, Int16, Int32, Int64, Int128)) +UNSIGNED_INTEGER: DTypeGroup = frozenset((UInt8, UInt16, UInt32, UInt64, UInt128)) +INTEGER: DTypeGroup = SIGNED_INTEGER.union(UNSIGNED_INTEGER) +FLOAT: DTypeGroup = frozenset((Float32, Float64)) +NUMERIC: DTypeGroup = FLOAT.union(INTEGER).union((Decimal,)) +NESTED: DTypeGroup = frozenset((Struct, List, Array)) +DATETIME: DTypeGroup = frozen_dtypes(Datetime, DatetimeV1) + +_FLOAT_PROMOTE: Mapping[FrozenDTypes, type[Float64]] = { + frozen_dtypes(Float32, Float64): Float64, + frozen_dtypes(Decimal, Float64): Float64, + frozen_dtypes(Decimal, Float32): Float64, +} + + +_TIME_UNIT_PER_SECOND: Mapping[TimeUnit, int] = { + "s": 1, + "ms": MS_PER_SECOND, + "us": US_PER_SECOND, + "ns": NS_PER_SECOND, +} + + +def _key_fn_time_unit(obj: Datetime | Duration, /) -> int: + return _TIME_UNIT_PER_SECOND[obj.time_unit] + + +@lru_cache(maxsize=_CACHE_SIZE // 2) +def dtype_eq(left: DType, right: DType, /) -> bool: + return left == right + + +@cache +def _integer_supertyping() -> Mapping[FrozenDTypes, type[Int | Float64]]: + """Generate the supertype conversion table for all integer data type pairs.""" + tps_int = SignedIntegerType.__subclasses__() + tps_uint = UnsignedIntegerType.__subclasses__() + get_bits: attrgetter[_Bits] = attrgetter("_bits") + ints = ( + (frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits)) + for lhs, rhs in product(tps_int, repeat=2) + ) + uints = ( + (frozen_dtypes(lhs, rhs), max(lhs, rhs, key=get_bits)) 
+ for lhs, rhs in product(tps_uint, repeat=2) + ) + # NOTE: `Float64` is here because `mypy` refuses to respect the last overload 😭 + # https://github.com/python/typeshed/blob/a564787bf23386e57338b750bf4733f3c978b701/stdlib/typing.pyi#L776-L781 + ubits_to_int: Mapping[_Bits, type[Int | Float64]] = {8: Int16, 16: Int32, 32: Int64} + mixed = ( + ( + frozen_dtypes(int_, uint), + int_ if int_._bits > uint._bits else ubits_to_int.get(uint._bits, Float64), + ) + for int_, uint in product(tps_int, tps_uint) + ) + return dict(chain(ints, uints, mixed)) + + +@cache +def _primitive_numeric_supertyping() -> Mapping[FrozenDTypes, type[Float]]: + """Generate the supertype conversion table for all (integer, float) data type pairs.""" + F32, F64 = Float32, Float64 # noqa: N806 + small_int = (Int8, Int16, UInt8, UInt16) + small_int_f32 = ((frozen_dtypes(tp, F32), F32) for tp in small_int) + big_int_f32 = ((frozen_dtypes(tp, F32), F64) for tp in INTEGER.difference(small_int)) + int_f64 = ((frozen_dtypes(tp, F64), F64) for tp in INTEGER) + return dict(chain(small_int_f32, big_int_f32, int_f64)) + + +def _first_excluding(base_types: FrozenDTypes, *exclude: type[DType]) -> type[DType]: + """Return an arbitrary element from base_types excluding the given types.""" + others = base_types.difference(exclude) + return next(iter(others)) + + +def _has_intersection(a: frozenset[Any], b: frozenset[Any], /) -> bool: + """Return True if sets share at least one element.""" + return not a.isdisjoint(b) + + +@lru_cache(maxsize=_CACHE_SIZE) +def has_nested(base_types: FrozenDTypes, /) -> bool: + return _has_intersection(base_types, NESTED) + + +def has_inner(dtype: Any) -> TypeIs[Array | List]: + return isinstance(dtype, (Array, List)) + + +@just_dispatch(upper_bound=DType) +def same_supertype(left: DType, right: DType, /) -> DType | None: + return left if dtype_eq(left, right) else None + + +@same_supertype.register(Duration, DurationV1) +@lru_cache(maxsize=_CACHE_SIZE * 2) +def 
downcast_time_unit(left: SameTemporalT, right: SameTemporalT, /) -> SameTemporalT: + """Return the operand with the lowest precision time unit.""" + return min(left, right, key=_key_fn_time_unit) + + +def _struct_fields_union( + left: Collection[Field], right: Collection[Field], / +) -> Struct | None: + """Adapted from [`union_struct_fields`](https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/utils/supertype.rs#L559-L586).""" + longest, shortest = (left, right) if len(left) >= len(right) else (right, left) + longest_map = {f.name: f.dtype() for f in longest} + for f in shortest: + name, dtype = f.name, f.dtype() + dtype_longest = longest_map.setdefault(name, dtype) + if not dtype_eq(dtype, dtype_longest): + if supertype := get_supertype(dtype, dtype_longest): + longest_map[name] = supertype + else: + return None + return Struct(longest_map) + + +@same_supertype.register(Struct) +def struct_supertype(left: Struct, right: Struct, /) -> Struct | None: + """Get the supertype of two struct data types. 
+ + Adapted from [`super_type_structs`](https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/utils/supertype.rs#L588-L603) + """ + left_fields, right_fields = left.fields, right.fields + if len(left_fields) != len(right_fields): + return _struct_fields_union(left_fields, right_fields) + new_fields = deque["Field"]() + for left_f, right_f in zip(left_fields, right_fields): + if left_f.name != right_f.name: + return _struct_fields_union(left_fields, right_fields) + if supertype := get_supertype(left_f.dtype(), right_f.dtype()): + new_fields.append(Field(left_f.name, supertype)) + else: + return None + return Struct(new_fields) + + +@same_supertype.register(Array) +def array_supertype(left: Array, right: Array, /) -> Array | None: + if (left.shape == right.shape) and ( + inner := get_supertype(left.inner(), right.inner()) + ): + return Array(inner, left.size) + return None + + +@same_supertype.register(List) +def list_supertype(left: List, right: List, /) -> List | None: + if inner := get_supertype(left.inner(), right.inner()): + return List(inner) + return None + + +@same_supertype.register(Datetime, DatetimeV1) +def datetime_supertype( + left: SameDatetimeT, right: SameDatetimeT, / +) -> SameDatetimeT | None: + if left.time_zone != right.time_zone: + return None + return downcast_time_unit(left, right) + + +@same_supertype.register(Enum) +def enum_supertype(left: Enum, right: Enum, /) -> Enum | None: + return left if left.categories == right.categories else None + + +@same_supertype.register(Decimal) +def decimal_supertype(left: Decimal, right: Decimal, /) -> Decimal: + # https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L508-L511 + precision = max(left.precision, right.precision) + scale = max(left.scale, right.scale) + return Decimal(precision=precision, scale=scale) + + +DEC128_MAX_PREC = 38 +# Precomputing powers of 10 up to 10^38 +POW10_LIST 
= tuple(10**i for i in range(DEC128_MAX_PREC + 1)) +INT_MAX_MAP: Mapping[Int, int] = { + UInt8(): (2**8) - 1, + UInt16(): (2**16) - 1, + UInt32(): (2**32) - 1, + UInt64(): (2**64) - 1, + Int8(): (2**7) - 1, + Int16(): (2**15) - 1, + Int32(): (2**31) - 1, + Int64(): (2**63) - 1, +} + + +def _integer_fits_in_decimal(value: int, precision: int, scale: int) -> bool: + """Scales an integer and checks if it fits the target precision.""" + # !NOTE: Indexing is safe since `scale <= precision <= 38` + return (precision == DEC128_MAX_PREC) or ( + value * POW10_LIST[scale] < POW10_LIST[precision] + ) + + +def _decimal_integer_supertyping(decimal: Decimal, integer: Int) -> DType | None: + precision, scale = decimal.precision, decimal.scale + + if integer in {UInt128(), Int128()}: + fits_orig_prec_scale = False + elif value := INT_MAX_MAP.get(integer, None): + fits_orig_prec_scale = _integer_fits_in_decimal(value, precision, scale) + else: # pragma: no cover + msg = "Unreachable integer type" + raise ValueError(msg) + + precision = precision if fits_orig_prec_scale else DEC128_MAX_PREC + return Decimal(precision, scale) + + +def _numeric_supertype( + left: DType, right: DType, base_types: FrozenDTypes +) -> DType | None: + """Get the supertype of two numeric data types that do not share the same class. + + `_{primitive_numeric,integer}_supertyping` define most valid numeric supertypes. + + We generate these on first use, with all subsequent calls returning the same mapping. 
+ """ + if NUMERIC.issuperset(base_types): + if INTEGER.issuperset(base_types): + return _integer_supertyping()[base_types]() + if tp := _FLOAT_PROMOTE.get(base_types): + return tp() + if Decimal in base_types: + # Logic adapted from rust implementation + # https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L517-L530 + decimal, integer = ( + (left, right) if isinstance(left, Decimal) else (right, left) + ) + return _decimal_integer_supertyping(decimal=decimal, integer=integer) # type: ignore[arg-type] + + return _primitive_numeric_supertyping()[base_types]() + if Boolean in base_types: + return _first_excluding(base_types, Boolean)() + return None + + +def _mixed_nested_supertype(left: DType, right: DType, /) -> DType | None: + if ( + has_inner(left) + and has_inner(right) + and (inner := get_supertype(left.inner(), right.inner())) + ): + return List(inner) + return None + + +def _mixed_supertype( + left: DType, right: DType, base_types: FrozenDTypes, / +) -> DType | None: + """Get the supertype of two data types that do not share the same class.""" + if Date in base_types and _has_intersection(base_types, DATETIME): + return left if isinstance(left, Datetime) else right + if String in base_types: + return (Binary if Binary in base_types else String)() + if has_nested(base_types): + return _mixed_nested_supertype(left, right) + return ( + _numeric_supertype(left, right, base_types) + if _has_intersection(NUMERIC, base_types) + else None + ) + + +def get_supertype(left: DType, right: DType) -> DType | None: + """Given two data types, determine the data type that both types can reasonably safely be cast to. + + Arguments: + left: First data type. + right: Second data type. + + Returns: + The common supertype that both types can be safely cast to, or None if no such type exists. 
+ """ + base_types = frozen_dtypes(left.base_type(), right.base_type()) + if Unknown in base_types: + return Unknown() + if len(base_types) == 1: + return same_supertype(left, right) + return _mixed_supertype(left, right, base_types) diff --git a/narwhals/dtypes/_utils.py b/narwhals/dtypes/_utils.py new file mode 100644 index 0000000000..9265a3fa5d --- /dev/null +++ b/narwhals/dtypes/_utils.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from narwhals._utils import isinstance_or_issubclass, qualified_type_name +from narwhals.dtypes._classes import DType, DTypeClass, NestedType + +if TYPE_CHECKING: + from typing import Any + + from typing_extensions import TypeIs + + from narwhals.typing import IntoDType + + +def validate_dtype(dtype: DType | type[DType]) -> None: + if not isinstance_or_issubclass(dtype, DType): + msg = ( + f"Expected Narwhals dtype, got: {type(dtype)}.\n\n" + "Hint: if you were trying to cast to a type, use e.g. nw.Int64 instead of 'int64'." + ) + raise TypeError(msg) + + +def is_into_dtype(obj: Any) -> TypeIs[IntoDType]: + return isinstance(obj, DType) or ( + isinstance(obj, DTypeClass) and not issubclass(obj, NestedType) + ) + + +def is_nested_type(obj: Any) -> TypeIs[type[NestedType]]: + return isinstance(obj, DTypeClass) and issubclass(obj, NestedType) + + +def validate_into_dtype(dtype: Any) -> None: + if not is_into_dtype(dtype): + if is_nested_type(dtype): + name = f"nw.{dtype.__name__}" + msg = ( + f"{name!r} is not valid in this context.\n\n" + f"Hint: instead of:\n\n" + f" {name}\n\n" + "use:\n\n" + f" {name}(...)" + ) + else: + msg = f"Expected Narwhals dtype, got: {qualified_type_name(dtype)!r}." 
+ raise TypeError(msg) diff --git a/narwhals/expr.py b/narwhals/expr.py index 5f48162bbe..97cc04c36c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -12,7 +12,7 @@ no_default, unstable, ) -from narwhals.dtypes import _validate_dtype +from narwhals.dtypes._utils import validate_dtype from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr_cat import ExprCatNamespace from narwhals.expr_dt import ExprDateTimeNamespace @@ -186,7 +186,7 @@ def cast(self, dtype: IntoDType) -> Self: | 2 3.0 8 | └──────────────────┘ """ - _validate_dtype(dtype) + validate_dtype(dtype) return self._append_node(ExprNode(ExprKind.ELEMENTWISE, "cast", dtype=dtype)) # --- binary --- diff --git a/narwhals/series.py b/narwhals/series.py index 0ce582e496..c6f6a867a7 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -30,7 +30,7 @@ unstable, ) from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar -from narwhals.dtypes import _validate_dtype, _validate_into_dtype +from narwhals.dtypes._utils import validate_dtype, validate_into_dtype from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.expr import Expr from narwhals.functions import col @@ -171,7 +171,7 @@ def from_numpy( msg = "`from_numpy` only accepts 1D numpy arrays" raise ValueError(msg) if dtype: - _validate_into_dtype(dtype) + validate_into_dtype(dtype) implementation = Implementation.from_backend(backend) if is_eager_allowed(implementation): ns = cls._version.namespace.from_backend(implementation).compliant @@ -230,7 +230,7 @@ def from_iterable( if is_numpy_array(values): return cls.from_numpy(name, values, dtype, backend=backend) if dtype: - _validate_into_dtype(dtype) + validate_into_dtype(dtype) if not isinstance(values, Iterable): msg = f"Expected values to be an iterable, got: {qualified_type_name(values)!r}." 
raise TypeError(msg) @@ -621,7 +621,7 @@ def cast(self, dtype: IntoDType) -> Self: ] ] """ - _validate_dtype(dtype) + validate_dtype(dtype) return self._with_compliant(self._compliant_series.cast(dtype)) def to_frame(self) -> DataFrame[Any]: diff --git a/narwhals/stable/v1/dtypes.py b/narwhals/stable/v1/dtypes.py index a292be8a60..553dbf9c33 100644 --- a/narwhals/stable/v1/dtypes.py +++ b/narwhals/stable/v1/dtypes.py @@ -1,16 +1,13 @@ from __future__ import annotations -from narwhals.stable.v1._dtypes import ( +from narwhals.dtypes._classes import ( Array, Binary, Boolean, Categorical, Date, - Datetime, Decimal, DType, - Duration, - Enum, Field, Float32, Float64, @@ -37,6 +34,7 @@ Unknown, UnsignedIntegerType, ) +from narwhals.dtypes._classes_v1 import Datetime, Duration, Enum __all__ = [ "Array", diff --git a/tests/conftest.py b/tests/conftest.py index 3e80bcdff4..bd814df6c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,15 @@ import pytest import narwhals as nw +from narwhals import dtypes from narwhals._utils import Implementation, generate_temporary_column_name -from tests.utils import ID_PANDAS_LIKE, PANDAS_VERSION, pyspark_session, sqlframe_session +from tests.utils import ( + ID_PANDAS_LIKE, + PANDAS_VERSION, + dtype_ids, + pyspark_session, + sqlframe_session, +) if TYPE_CHECKING: from collections.abc import Sequence @@ -371,7 +378,7 @@ def eager_implementation(request: pytest.FixtureRequest) -> EagerAllowed: nw.Unknown, nw.Binary, ], - ids=lambda tp: tp.__name__, + ids=dtype_ids, ) def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: tp_dtype: type[NonNestedDType] = request.param @@ -385,8 +392,44 @@ def non_nested_type(request: pytest.FixtureRequest) -> type[NonNestedDType]: nw.Struct({"a": nw.Boolean}), nw.Enum(["beluga", "narwhal"]), ], - ids=lambda obj: type(obj).__name__, + ids=dtype_ids, ) def nested_dtype(request: pytest.FixtureRequest) -> NestedOrEnumDType: dtype: NestedOrEnumDType = request.param return 
dtype + + +@pytest.fixture( + params=[ + nw.Decimal(), + *dtypes.SignedIntegerType.__subclasses__(), + *dtypes.UnsignedIntegerType.__subclasses__(), + *dtypes.FloatType.__subclasses__(), + ], + ids=dtype_ids, +) +def numeric_dtype(request: pytest.FixtureRequest) -> dtypes.NumericType: + dtype: dtypes.NumericType = request.param + return dtype + + +@pytest.fixture( + params=[ + nw.Time(), + nw.Date(), + nw.Datetime(), + nw.Datetime("s"), + nw.Datetime("ns"), + nw.Datetime("us"), + nw.Datetime("ms"), + nw.Duration("s"), + nw.Duration("ns"), + nw.Duration("us"), + nw.Duration("ms"), + ], + ids=dtype_ids, +) +def naive_temporal_dtype(request: pytest.FixtureRequest) -> dtypes.TemporalType: + """All `TemporalType`s in `nw.dtypes`, excluding `time_zone` info.""" + dtype: dtypes.TemporalType = request.param + return dtype diff --git a/tests/dispatch_test.py b/tests/dispatch_test.py new file mode 100644 index 0000000000..2745ba6bf4 --- /dev/null +++ b/tests/dispatch_test.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import decimal +import string + +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals._dispatch import JustDispatch, just_dispatch +from narwhals.dtypes import DType, SignedIntegerType + + +def type_name(obj: object) -> str: + return type(obj).__name__ + + +@pytest.fixture +def dispatch_no_bound() -> JustDispatch[str]: + @just_dispatch + def dtype_repr_code(dtype: DType) -> str: + return type_name(dtype).lower() + + return dtype_repr_code + + +@pytest.fixture +def dispatch_upper_bound() -> JustDispatch[str]: + @just_dispatch(upper_bound=DType) + def dtype_repr_code(dtype: DType) -> str: + return type_name(dtype).lower() + + return dtype_repr_code + + +@pytest.fixture +def stdlib_decimal() -> decimal.Decimal: + return decimal.Decimal("1.0") + + +def test_just_dispatch( + dispatch_no_bound: JustDispatch[str], stdlib_decimal: decimal.Decimal +) -> None: + i64 = nw.Int64() + assert dispatch_no_bound(i64) == "int64" 
+ + @dispatch_no_bound.register(*SignedIntegerType.__subclasses__()) + def repr_int(dtype: SignedIntegerType) -> str: + return f"i{type_name(dtype).strip(string.ascii_letters)}" + + assert dispatch_no_bound(i64) == "i64" + assert repr_int(i64) == "i64" + assert dispatch_no_bound(nw.UInt8()) == "uint8" + assert dispatch_no_bound(stdlib_decimal) == "decimal" + + +def test_just_dispatch_upper_bound( + dispatch_upper_bound: JustDispatch[str], stdlib_decimal: decimal.Decimal +) -> None: + i64 = nw.Int64() + assert dispatch_upper_bound(i64) == "int64" + + @dispatch_upper_bound.register(*SignedIntegerType.__subclasses__()) + def repr_int(dtype: SignedIntegerType) -> str: + return f"i{type_name(dtype).strip(string.ascii_letters)}" + + assert dispatch_upper_bound(i64) == "i64" + assert repr_int(i64) == "i64" + assert dispatch_upper_bound(nw.UInt8()) == "uint8" + assert dispatch_upper_bound(nw.Decimal()) == "decimal" + + with pytest.raises(TypeError, match=r"'dtype_repr_code' does not support 'Decimal'"): + dispatch_upper_bound(stdlib_decimal) + + dispatch_upper_bound.register(type(stdlib_decimal))(lambda _: "need to be explicit") + assert dispatch_upper_bound(stdlib_decimal) == "need to be explicit" + + +def test_just_dispatch_no_auto_subclass(dispatch_no_bound: JustDispatch[str]) -> None: + NOT_REGISTERED = "datetime" # noqa: N806 + TZ = "Africa/Accra" # noqa: N806 + + assert dispatch_no_bound(nw.Datetime("ms")) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime("us")) == NOT_REGISTERED + + @dispatch_no_bound.register(nw.Datetime) + def repr_datetime(dtype: nw.Datetime) -> str: + if dtype.time_zone is None: + args: str = dtype.time_unit + else: + args = f"{dtype.time_unit}, {dtype.time_zone}" + return f"datetime[{args}]" + + assert dispatch_no_bound(nw.Datetime()) == "datetime[us]" + assert dispatch_no_bound(nw.Datetime("s")) == "datetime[s]" + assert dispatch_no_bound(nw.Datetime(time_zone=TZ)) == f"datetime[us, {TZ}]" + + assert 
dispatch_no_bound(nw_v1.Datetime()) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime("s")) == NOT_REGISTERED + assert dispatch_no_bound(nw_v1.Datetime(time_zone=TZ)) == NOT_REGISTERED + + dispatch_no_bound.register(nw_v1.Datetime)(repr_datetime) + + assert dispatch_no_bound(nw_v1.Datetime()) == "datetime[us]" + assert dispatch_no_bound(nw_v1.Datetime("s")) == "datetime[s]" + assert dispatch_no_bound(nw_v1.Datetime(time_zone=TZ)) == f"datetime[us, {TZ}]" diff --git a/tests/dtypes/get_supertype_test.py b/tests/dtypes/get_supertype_test.py new file mode 100644 index 0000000000..cc84b811b1 --- /dev/null +++ b/tests/dtypes/get_supertype_test.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals.dtypes._supertyping import get_supertype +from tests.utils import dtype_ids + +if TYPE_CHECKING: + from collections.abc import Mapping + + from typing_extensions import TypeAlias + + from narwhals.dtypes import DType, NumericType, TemporalType + from narwhals.typing import IntoDType + + IntoStruct: TypeAlias = Mapping[str, IntoDType] + + +def _check_supertype(left: DType, right: DType, expected: DType | None) -> None: + result = get_supertype(left, right) + if expected is None: + assert result is None + else: + assert result is not None + assert result == expected + + +@pytest.mark.parametrize( + "dtype", + [ + nw.Array(nw.Binary(), shape=(2,)), + nw.Array(nw.Boolean, shape=(2, 3)), + nw.Binary(), + nw.Boolean(), + nw.Categorical(), + nw.Date(), + nw.Datetime(), + nw.Datetime(time_unit="ns", time_zone="Europe/Berlin"), + nw.Decimal(), + nw.Duration(), + nw.Enum(["orca", "narwhal"]), + nw.Enum([]), + nw.Float32(), + nw.Float64(), + nw.Int8(), + nw.Int16(), + nw.Int32(), + nw.Int64(), + nw.Int128(), + nw.List(nw.String), + nw.List(nw.Array(nw.Int32, shape=(5, 3))), + nw.Object(), + nw.String(), + nw.Struct({"r2": nw.Float64(), "mse": 
nw.Float32()}), + nw.Struct({"a": nw.String, "b": nw.List(nw.Int32)}), + nw.Time(), + nw.UInt8(), + nw.UInt16(), + nw.UInt32(), + nw.UInt64(), + nw.UInt128(), + nw.Unknown(), + ], + ids=dtype_ids, +) +def test_identical_dtype(dtype: DType) -> None: + _check_supertype(dtype, dtype, dtype) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + (nw.Datetime("ns"), nw.Datetime("us"), nw.Datetime("us")), + (nw.Datetime("s"), nw.Datetime("us"), nw.Datetime("s")), + (nw.Datetime("s"), nw.Datetime("s", "Africa/Accra"), None), + (nw.Datetime(time_zone="Asia/Kathmandu"), nw.Datetime(), None), + ( + nw.Enum(["beluga", "narwhal", "orca"]), + nw.Enum(["dog", "cat", "fish with legs"]), + None, + ), + (nw.Enum([]), nw.Enum(["fruit", "other food"]), None), + (nw.List(nw.Int64), nw.List(nw.Int64()), nw.List(nw.Int64())), + (nw.List(nw.UInt16()), nw.List(nw.Int32), nw.List(nw.Int32())), + (nw.List(nw.Date), nw.List(nw.Binary), None), + (nw.List(nw.Unknown), nw.List(nw.Float64), nw.List(nw.Unknown())), + ( + nw.Array(nw.Float32, shape=2), + nw.Array(nw.Float64, shape=2), + nw.Array(nw.Float64, shape=2), + ), + (nw.Array(nw.Int64, shape=1), nw.Array(nw.Int64, shape=4), None), + ( + nw.Array(nw.Decimal, shape=1), + nw.Array(nw.Unknown, shape=1), + nw.Array(nw.Unknown, shape=1), + ), + ( + nw.Array(nw.UInt128, shape=3), + nw.Array(nw.Decimal, shape=3), + nw.Array(nw.Decimal, shape=3), + ), + ( + nw.Array(nw.String, shape=1), + nw.Array(nw.Int64, shape=1), + nw.Array(nw.String, shape=1), + ), + ], + ids=dtype_ids, +) +def test_same_class(left: DType, right: DType, expected: DType | None) -> None: + _check_supertype(left, right, expected) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + ( + {"f0": nw.Duration("ms"), "f1": nw.Int64, "f2": nw.Int64}, + {"f0": nw.Duration("us"), "f1": nw.Int64()}, + {"f0": nw.Duration("ms"), "f1": nw.Int64(), "f2": nw.Int64()}, + ), + ( + {"f0": nw.Float64, "f1": nw.Date, "f2": nw.Int32}, + {"f0": nw.Float32, "f1": 
nw.Datetime, "f3": nw.UInt8}, + {"f0": nw.Float64, "f1": nw.Datetime(), "f2": nw.Int32, "f3": nw.UInt8}, + ), + ( + {"f0": nw.Int32, "f1": nw.Boolean, "f2": nw.String}, + {"f0": nw.Unknown}, + {"f0": nw.Unknown, "f1": nw.Boolean, "f2": nw.String}, + ), + ( + {"f0": nw.Object, "f1": nw.List(nw.Boolean)}, + {"f0": nw.List(nw.Boolean), "f1": nw.List(nw.Boolean)}, + None, + ), + ({"f0": nw.Binary()}, {"f0": nw.Datetime("s"), "f1": nw.Date}, None), + ( + {"f0": nw.Int64, "f1": nw.Struct({"f1": nw.Float32, "f0": nw.String})}, + { + "f0": nw.UInt8, + "f1": nw.Struct( + {"f0": nw.Categorical, "f1": nw.Float64(), "f2": nw.Time} + ), + }, + { + "f0": nw.Int64, + "f1": nw.Struct({"f0": nw.String, "f1": nw.Float64, "f2": nw.Time}), + }, + ), + ( + {"F0": nw.UInt8, "f0": nw.Int16}, + {"f0": nw.Int128, "f1": nw.UInt16, " f0": nw.Int8}, + {"f0": nw.Int128, "f1": nw.UInt16, " f0": nw.Int8, "F0": nw.UInt8}, + ), + ], + ids=dtype_ids, +) +def test_struct(left: IntoStruct, right: IntoStruct, expected: IntoStruct | None) -> None: + expected_ = None if expected is None else nw.Struct(expected) + _check_supertype(nw.Struct(left), nw.Struct(right), expected_) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + (nw.Datetime("ns"), nw.Date(), nw.Datetime("ns")), + (nw.Date(), nw.Datetime(), nw.Datetime()), + (nw.Datetime(), nw.Int8(), None), + (nw.String(), nw.Categorical(), nw.String()), + (nw.Enum(["hello"]), nw.Categorical(), None), + (nw.Enum(["hello"]), nw.String(), nw.String()), + (nw.Binary(), nw.String(), nw.Binary()), + ], + ids=dtype_ids, +) +def test_mixed_dtype(left: DType, right: DType, expected: DType | None) -> None: + _check_supertype(left, right, expected) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + # Same depth) + (nw.List(nw.Int64), nw.Array(nw.Int32, shape=2), nw.List(nw.Int64())), + (nw.List(nw.Float32), nw.Array(nw.Float64, shape=3), nw.List(nw.Float64())), + ( + nw.List(nw.List(nw.Int8)), + nw.Array(nw.Int16, shape=(2, 3)), 
+ nw.List(nw.List(nw.Int16())), + ), + (nw.List(nw.String), nw.Array(nw.Int64, shape=2), nw.List(nw.String())), + # Incompatible inner types + (nw.List(nw.Categorical), nw.Array(nw.Int64, shape=2), None), + # Depth mismatch + (nw.List(nw.Int64), nw.Array(nw.Int32, shape=(2, 3)), None), + (nw.List(nw.List(nw.Int64)), nw.Array(nw.Int32, shape=2), None), + ], + ids=dtype_ids, +) +def test_list_array_supertype(left: DType, right: DType, expected: DType | None) -> None: + _check_supertype(left, right, expected) + _check_supertype(right, left, expected) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + # {X, String} -> String conversions + (nw.Int64(), nw.String(), nw.String()), + (nw.Float32(), nw.String(), nw.String()), + (nw.Boolean(), nw.String(), nw.String()), + (nw.Date(), nw.String(), nw.String()), + (nw.Datetime(), nw.String(), nw.String()), + (nw.Duration(), nw.String(), nw.String()), + (nw.List(nw.Int64), nw.String(), nw.String()), + (nw.Array(nw.Int64, shape=2), nw.String(), nw.String()), + (nw.Struct({"a": nw.Int64}), nw.String(), nw.String()), + (nw.Decimal(), nw.String(), nw.String()), + (nw.Time(), nw.String(), nw.String()), + # Binary + String -> Binary (exception to the rule) + (nw.Binary(), nw.String(), nw.Binary()), + ], + ids=dtype_ids, +) +def test_string_supertype(left: DType, right: DType, expected: DType) -> None: + _check_supertype(left, right, expected) + _check_supertype(right, left, expected) + + +def test_mixed_integer_temporal( + naive_temporal_dtype: TemporalType, numeric_dtype: NumericType +) -> None: + _check_supertype(naive_temporal_dtype, numeric_dtype, None) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + # NOTE: The order of the case *should not* matter (some are flipped for coverage) + # signed + signed + (nw.Int8(), nw.Int16(), nw.Int16()), + (nw.Int8(), nw.Int32(), nw.Int32()), + (nw.Int8(), nw.Int64(), nw.Int64()), + (nw.Int8(), nw.Int128(), nw.Int128()), + (nw.Int16(), nw.Int32(), 
nw.Int32()), + (nw.Int64(), nw.Int16(), nw.Int64()), + (nw.Int16(), nw.Int128(), nw.Int128()), + (nw.Int64(), nw.Int32(), nw.Int64()), + (nw.Int32(), nw.Int128(), nw.Int128()), + (nw.Int64(), nw.Int128(), nw.Int128()), + # unsigned + unsigned + (nw.UInt8(), nw.UInt16(), nw.UInt16()), + (nw.UInt32(), nw.UInt8(), nw.UInt32()), + (nw.UInt8(), nw.UInt64(), nw.UInt64()), + (nw.UInt8(), nw.UInt128(), nw.UInt128()), + (nw.UInt16(), nw.UInt32(), nw.UInt32()), + (nw.UInt16(), nw.UInt64(), nw.UInt64()), + (nw.UInt128(), nw.UInt16(), nw.UInt128()), + (nw.UInt32(), nw.UInt64(), nw.UInt64()), + (nw.UInt32(), nw.UInt128(), nw.UInt128()), + (nw.UInt64(), nw.UInt128(), nw.UInt128()), + # signed + unsigned + (nw.Int8(), nw.UInt8(), nw.Int16()), + (nw.Int8(), nw.UInt16(), nw.Int32()), + (nw.UInt32(), nw.Int8(), nw.Int64()), + (nw.Int8(), nw.UInt64(), nw.Float64()), + (nw.Int16(), nw.UInt8(), nw.Int16()), + (nw.Int16(), nw.UInt16(), nw.Int32()), + (nw.UInt32(), nw.Int16(), nw.Int64()), + (nw.Int16(), nw.UInt64(), nw.Float64()), + (nw.Int32(), nw.UInt8(), nw.Int32()), + (nw.UInt16(), nw.Int32(), nw.Int32()), + (nw.Int32(), nw.UInt32(), nw.Int64()), + (nw.Int32(), nw.UInt64(), nw.Float64()), + (nw.Int64(), nw.UInt8(), nw.Int64()), + (nw.UInt16(), nw.Int64(), nw.Int64()), + (nw.Int64(), nw.UInt32(), nw.Int64()), + (nw.Int64(), nw.UInt64(), nw.Float64()), + # float + float + (nw.Float32(), nw.Float64(), nw.Float64()), + (nw.Float64(), nw.Float32(), nw.Float64()), + # float + integer + (nw.Int8(), nw.Float32(), nw.Float32()), + (nw.Int16(), nw.Float32(), nw.Float32()), + (nw.Float32(), nw.UInt8(), nw.Float32()), + (nw.Float32(), nw.UInt16(), nw.Float32()), + (nw.Int32(), nw.Float32(), nw.Float64()), + (nw.Int64(), nw.Float32(), nw.Float64()), + (nw.UInt32(), nw.Float32(), nw.Float64()), + (nw.Float32(), nw.UInt64(), nw.Float64()), + (nw.Int8(), nw.Float64(), nw.Float64()), + (nw.Float64(), nw.Int64(), nw.Float64()), + # float + decimal + (nw.Decimal(), nw.Float32(), nw.Float64()), + 
(nw.Decimal(), nw.Float64(), nw.Float64()), + # decimal + decimal + (nw.Decimal(5, 2), nw.Decimal(4, 3), nw.Decimal(5, 3)), + (nw.Decimal(scale=12), nw.Decimal(18, scale=9), nw.Decimal(38, 12)), + # decimal + integer + (nw.Decimal(4, 1), nw.UInt8(), nw.Decimal(4, 1)), + (nw.Decimal(5, 2), nw.Int8(), nw.Decimal(5, 2)), + (nw.Decimal(10, 0), nw.Int32(), nw.Decimal(10, 0)), + (nw.Decimal(15, 2), nw.UInt32(), nw.Decimal(15, 2)), + (nw.Decimal(2, 1), nw.UInt8, nw.Decimal(38, 1)), + (nw.Decimal(10, 5), nw.Int64, nw.Decimal(38, 5)), + (nw.Decimal(38, 0), nw.Int128, nw.Decimal(38, 0)), + (nw.Decimal(1, 0), nw.UInt8(), nw.Decimal(38, 0)), + (nw.Decimal(38, 38), nw.Int8(), nw.Decimal(38, 38)), + (nw.Decimal(10, 1), nw.UInt32(), nw.Decimal(38, 1)), + ], + ids=dtype_ids, +) +def test_numeric_promotion(left: DType, right: DType, expected: DType) -> None: + _check_supertype(left, right, expected) + _check_supertype(right, left, expected) + + +def test_numeric_and_bool_promotion(numeric_dtype: NumericType) -> None: + _check_supertype(numeric_dtype, nw.Boolean(), numeric_dtype) + _check_supertype(nw.Boolean(), numeric_dtype, numeric_dtype) + + +@pytest.mark.parametrize( + ("left", "right", "expected"), + [ + (nw_v1.Datetime(), nw_v1.Datetime(), nw_v1.Datetime()), + (nw_v1.Datetime("ns"), nw_v1.Datetime("s"), nw_v1.Datetime("s")), + ( + nw_v1.Datetime(time_zone="Europe/Berlin"), + nw_v1.Datetime(time_zone="Europe/Berlin"), + nw_v1.Datetime(time_zone="Europe/Berlin"), + ), + ( + nw_v1.Datetime(time_zone="Europe/Berlin"), + nw_v1.Datetime("ms", "Europe/Berlin"), + nw_v1.Datetime("ms", "Europe/Berlin"), + ), + (nw_v1.Datetime(time_zone="Europe/Berlin"), nw_v1.Datetime(), None), + (nw_v1.Datetime("s"), nw_v1.Datetime("s", "Africa/Accra"), None), + (nw_v1.Duration("ns"), nw_v1.Duration("ms"), nw_v1.Duration("ms")), + (nw_v1.Duration(), nw_v1.Duration(), nw_v1.Duration()), + (nw_v1.Duration("s"), nw_v1.Duration(), nw_v1.Duration("s")), + (nw_v1.Duration(), nw_v1.Datetime(), None), + 
(nw_v1.Enum(), nw_v1.Enum(), nw_v1.Enum()), + (nw_v1.Enum(), nw_v1.String(), nw_v1.String()), + ( + nw.Date(), + nw_v1.Datetime(time_zone="Europe/Berlin"), + nw_v1.Datetime(time_zone="Europe/Berlin"), + ), + ( + nw.Struct({"f0": nw_v1.Duration("ms"), "f1": nw.Int64, "f2": nw.Int64}), + nw.Struct({"f0": nw_v1.Duration("us"), "f1": nw.Int64()}), + nw.Struct({"f0": nw_v1.Duration("ms"), "f1": nw.Int64(), "f2": nw.Int64()}), + ), + ( + nw.Struct({"f0": nw.Float64, "f1": nw.Date, "f2": nw.Int32}), + nw.Struct({"f0": nw.Float32, "f1": nw_v1.Datetime, "f3": nw.UInt8}), + nw.Struct( + {"f0": nw.Float64, "f1": nw_v1.Datetime(), "f2": nw.Int32, "f3": nw.UInt8} + ), + ), + ( + nw.Struct({"f0": nw.Binary()}), + nw.Struct({"f0": nw_v1.Datetime("s"), "f1": nw.Date}), + None, + ), + ( + nw.Array(nw.Date, shape=3), + nw.Array(nw_v1.Datetime, shape=3), + nw.Array(nw_v1.Datetime, shape=3), + ), + (nw.Array(nw.Date, shape=1), nw.Array(nw_v1.Datetime, shape=3), None), + (nw.Array(nw.Date, shape=2), nw.Array(nw_v1.Duration, shape=2), None), + ( + nw.Array(nw_v1.Enum(), shape=4), + nw.Array(nw_v1.Enum(), shape=4), + nw.Array(nw_v1.Enum(), shape=4), + ), + ], + ids=dtype_ids, +) +def test_v1_dtypes(left: DType, right: DType, expected: DType | None) -> None: + result = get_supertype(left, right) + if expected is None: + assert result is None + else: + assert result is not None + assert result == expected + # Must also preserve v1-ness + assert type(result) is type(expected) + + +@pytest.mark.parametrize( + ("dtype_v1", "dtype_main"), + [ + (nw_v1.Duration("ms"), nw.Duration("ms")), + (nw_v1.Duration("ns"), nw.Duration("ns")), + (nw_v1.Datetime(time_unit="ms"), nw.Datetime(time_unit="ms")), + (nw_v1.Datetime(time_zone="Europe/Rome"), nw.Datetime(time_zone="Europe/Rome")), + (nw_v1.Enum(), nw.Enum([])), + ], +) +def test_mixed_versions_return_none(dtype_v1: DType, dtype_main: DType) -> None: + _check_supertype(dtype_v1, dtype_main, None) + _check_supertype(dtype_main, dtype_v1, None) diff 
--git a/tests/utils.py b/tests/utils.py index 4d01223b2a..4401d47b3f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,8 +11,10 @@ import pytest import narwhals as nw +import narwhals.stable.v1 as nw_v1 from narwhals._utils import Implementation, parse_version, zip_strict from narwhals.dependencies import get_pandas +from narwhals.dtypes import DType from narwhals.translate import from_native if TYPE_CHECKING: @@ -263,6 +265,32 @@ def time_unit_compat(time_unit: TimeUnit, request: pytest.FixtureRequest, /) -> return time_unit +# TODO @dangotbanned: Add a real doc with examples +def dtype_ids(obj: DType | type[DType] | None) -> str: # noqa: PLR0911 + """Some tweaks to `DType.__repr__` for more readable test ids.""" + if obj is None: + return str(obj) + if isinstance(obj, DType): + if obj.__slots__: + if isinstance(obj, nw.Datetime): + return f"Datetime[{obj.time_unit}, {obj.time_zone}]" + if isinstance(obj, nw.Duration): + return f"Duration[{obj.time_unit}]" + if isinstance(obj, nw.Enum): + if isinstance(obj, nw_v1.Enum): + return "v1.Enum[]" # pragma: no cover + return f"Enum{list(obj.categories)!r}" + if isinstance(obj, nw.Array): + dtype: Any = obj + for _ in obj.shape: + dtype = dtype.inner + return f"Array[{dtype!r}, {obj.shape}]" + # non-empty slots == parameters + return repr(obj) + return obj.__class__.__name__ + return repr(obj) + + def is_pyspark_connect(constructor: Constructor) -> bool: is_spark_connect = bool(os.environ.get("SPARK_CONNECT", None)) return is_spark_connect and ("pyspark" in str(constructor)) diff --git a/utils/generate_supertyping.py b/utils/generate_supertyping.py new file mode 100644 index 0000000000..52b4a07372 --- /dev/null +++ b/utils/generate_supertyping.py @@ -0,0 +1,95 @@ +"""Adapted from [@FBruzzesi script (2026-01-11)]. 
+ +[@FBruzzesi script (2026-01-11)]: https://github.com/narwhals-dev/narwhals/pull/3396#issuecomment-3733465005 +""" + +from __future__ import annotations + +from itertools import product +from pathlib import Path +from typing import Final, TypeVar + +import polars as pl +from jinja2 import Template + +from narwhals.dtypes import DType, Enum, Unknown +from narwhals.dtypes._supertyping import get_supertype + +T = TypeVar("T") + +TEMPLATE_PATH: Final[Path] = Path("utils") / "promotion-rules.md.jinja" +DESTINATION_PATH: Final[Path] = Path("docs") / "concepts" / "promotion-rules.md" + + +def get_leaf_subclasses(cls: type[T]) -> list[type[T]]: + """Get all leaf subclasses (classes with no further subclasses).""" + leaves = [] + for subclass in cls.__subclasses__(): + if subclass.__subclasses__(): # Has children, recurse + leaves.extend(get_leaf_subclasses(subclass)) + else: # No children, it's a "leaf" + leaves.append(subclass) + return leaves + + +def collect_supertypes() -> None: + from narwhals.dtypes import _classes as _classes, _classes_v1 as _classes_v1 # noqa: I001, PLC0414 + + dtypes = get_leaf_subclasses(DType) + supertypes: list[tuple[str, str, str]] = [] + for left, right in product(dtypes, dtypes): + promoted: str + base_types = frozenset((left, right)) + left_str, right_str = str(left), str(right) + + if Unknown in base_types: + promoted = str(Unknown) + elif left is right: + promoted = str(left) + elif left.is_nested() or right.is_nested(): + promoted = "" + else: + if left is Enum: + left = Enum(["tmp"]) # noqa: PLW2901 + if right is Enum: + right = Enum(["tmp"]) # noqa: PLW2901 + + _promoted = get_supertype(left(), right()) + promoted = str(_promoted.__class__) if _promoted else "" + + supertypes.append((left_str, right_str, promoted)) + + frame = ( + pl.DataFrame(supertypes, schema=["_", "right", "supertype"], orient="row") + .pivot( + index="_", + on="right", + values="supertype", + aggregate_function=None, + sort_columns=True, + ) + .sort("_") + 
.with_columns(_=pl.format("**{}**", pl.col("_"))) + .rename({"_": ""}) + ) + + with ( + pl.Config( + tbl_rows=30, + tbl_cols=30, + tbl_formatting="MARKDOWN", + tbl_hide_column_data_types=True, + tbl_hide_dataframe_shape=True, + tbl_cell_alignment="LEFT", + tbl_width_chars=-1, + ), + TEMPLATE_PATH.open(mode="r") as stream, + DESTINATION_PATH.open(mode="w", encoding="utf-8", newline="\n") as file, + ): + content = Template(stream.read()).render({"promotion_rules_table": str(frame)}) + + file.write(content) + file.write("\n") + + +collect_supertypes() diff --git a/utils/promotion-rules.md.jinja b/utils/promotion-rules.md.jinja new file mode 100644 index 0000000000..963e19b20a --- /dev/null +++ b/utils/promotion-rules.md.jinja @@ -0,0 +1,411 @@ +# Data type promotion rules + +When combining columns of different data types (e.g., in `concat(..., how="vertical_relaxed")`), +Narwhals determines a common **supertype**[^1], the most specific type that both can safely be cast to. + +This page documents the rules used to derive that supertype. + +The implementation aims to follow the rules defined by +[Polars' `get_supertype_with_options`][get_supertype_with_options]. + +!!! tip + + If you are in a hurry, we have a single view to check all the combinations that we handle. + Skip to the following section: [Everything in a single table](#everything-in-a-single-table) + +```python exec="1" session="promotion-rules" +import narwhals as nw +from narwhals.dtypes._supertyping import get_supertype +def st(left, right): + s_left, s_right = f"nw.{left!r}", f"nw.{right!r}" + reprs = (s if s.endswith(")") else s + "()" for s in (s_left, s_right)) + args = ", ".join(reprs) + print(f"get_supertype({args}) == {get_supertype(left, right)}") +``` + +## Same Type + +When both operands share the same [`base_type`][base-type], the behavior depends on whether the type is +**parametric**[^2]. 
+ +* **Non-parametric types** (e.g., `Int32`, `Float64`, `String`, `Boolean`): the supertype is the type itself if both + are equal, otherwise no supertype exists. +* For **parametric types** (such as `Datetime`, `Duration`, `List`, `Struct`, `Enum`) the equality of the base type is + not sufficient — the parameters must also be compatible. + We have specialized rules for each parametric type described in the following sections. + +## Numeric Types + +Numeric supertyping follows a hierarchy designed to preserve precision while avoiding overflow. + +### Integer with Integer (same sign) + +When both operands are signed integers or both are unsigned integers, the supertype is the one with +the **higher bit-width**. For example: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Int8(), nw.Int32()) +st(nw.UInt16(), nw.UInt64()) +``` + +### Integer with Integer (mixed sign) + +When mixing signed and unsigned integers: + +* If the signed integer has a **strictly higher** bit-width than the unsigned, + use the signed type: + + ```python exec="1" session="promotion-rules" result="python" + st(nw.UInt8(), nw.Int16()) + st(nw.UInt16(), nw.Int32()) + ``` + +* Otherwise, promote the unsigned to the **next higher** signed bit-width (up to `Int64`): + + ```python exec="1" session="promotion-rules" result="python" + st(nw.UInt8(), nw.Int8()) + st(nw.UInt16(), nw.Int16()) + st(nw.UInt32(), nw.Int32()) + ``` + +* For `UInt64` or `UInt128` mixed with any signed integer, the supertype is `Float64` + (since no signed integer type can safely represent the full range of `UInt64`/`UInt128`): + + ```python exec="1" session="promotion-rules" result="python" + st(nw.UInt64(), nw.Int8()) + st(nw.UInt128(), nw.Int16()) + st(nw.UInt64(), nw.Int32()) + ``` + +### Integer with Float + +* Small integers (`Int8`, `Int16`, `UInt8`, `UInt16`) combined with `Float32` are promoted to `Float32` + + ```python exec="1" session="promotion-rules" result="python" + st(nw.Int8(), 
nw.Float32()) + st(nw.Int16(), nw.Float32()) + st(nw.UInt8(), nw.Float32()) + st(nw.UInt16(), nw.Float32()) + ``` + +* Larger integers (`Int32`, `Int64`, `Int128`, `UInt32`, `UInt64`, `UInt128`) combined with `Float32` + are promoted to `Float64` + + ```python exec="1" session="promotion-rules" result="python" + st(nw.Int32(), nw.Float32()) + st(nw.Int64(), nw.Float32()) + st(nw.Int128(), nw.Float32()) + st(nw.UInt32(), nw.Float32()) + st(nw.UInt64(), nw.Float32()) + st(nw.UInt128(), nw.Float32()) + ``` + +* Any integer combined with `Float64` is promoted to `Float64`: + + ```python exec="1" session="promotion-rules" result="python" + st(nw.Int32(), nw.Float64()) + st(nw.UInt32(), nw.Float64()) + st(nw.UInt64(), nw.Float64()) + ``` + +### Float with Float + +The combination of a `Float32` with a `Float64` is promoted to `Float64`: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Float32(), nw.Float64()) +st(nw.Float64(), nw.Float32()) +``` + +### Decimal + +* The combination of a `Decimal` with an integer is promoted to `Decimal`: + + ```python exec="1" session="promotion-rules" result="python" + st(nw.Decimal(), nw.Int32()) + st(nw.Decimal(), nw.Int64()) + ``` + +* The combination of a `Decimal` with a `Float32` or `Float64` is promoted to `Float64`: + + ```python exec="1" session="promotion-rules" result="python" + st(nw.Decimal(), nw.Float32()) + st(nw.Decimal(), nw.Float64()) + ``` + + + +## Temporal Types + +### Duration + +Two `Duration` types always have a supertype, namely the type with the **less precise** (coarser) time unit. 
+For example: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Duration('us'), nw.Duration('ms')) +st(nw.Duration('s'), nw.Duration('ms')) +``` + +Time unit precision order (from coarsest to finest): `s` < `ms` < `us` < `ns` + +### Datetime + +Two `Datetime` types have a supertype only if they share the **same time zone**: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Datetime('us'), nw.Datetime('ns')) + +tz = "Europe/Berlin" +print(f"{tz = !r}") +st(nw.Datetime(time_zone=tz), nw.Datetime(time_zone=tz)) +``` + +The resulting time unit is the **less precise** (coarser) of the two as defined in the previous section on `Duration`. + +If they do not share the same time zone, no supertype exists: + +```python exec="1" session="promotion-rules" result="python" +tz1 = "Europe/Berlin" +tz2 = "Europe/Paris" +print(f"{tz1 = !r}") +print(f"{tz2 = !r}") +st(nw.Datetime(time_zone=tz1), nw.Datetime(time_zone=tz2)) +``` + +### Datetime and Date + +When combining a `Datetime` with a `Date`, the supertype is the `Datetime` (preserving its time unit and time zone). 
+ +```python exec="1" session="promotion-rules" result="python" +st(nw.Datetime('us'), nw.Date()) +st(nw.Datetime('ms', time_zone='Europe/Berlin'), nw.Date()) +``` + +## String-like Types + +Any type combined with `String` is promoted to `String`, except for `Binary`: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Int64(), nw.String()) +st(nw.Float32(), nw.String()) +st(nw.Boolean(), nw.String()) +st(nw.Date(), nw.String()) +st(nw.Datetime(), nw.String()) +st(nw.String(), nw.Categorical()) +st(nw.String(), nw.Enum(['orca', 'narwhal'])) +st(nw.List(nw.Int64()), nw.String()) +st(nw.Array(nw.Int64(), shape=2), nw.String()) +st(nw.Struct({"a": nw.Int64()}), nw.String()) +``` + +The combination of a `String` with a `Binary` is promoted to `Binary`: + +```python exec="1" session="promotion-rules" result="python" +st(nw.String(), nw.Binary()) +``` + +### Binary + +The combination of a `Binary` with a `String` is promoted to `Binary`: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Binary(), nw.String()) +``` + +All other combinations involving `Binary` have no supertype. + +### Enum + +Two `Enum` types have a supertype only if they share the **exact same categories**. +In that case, the supertype is the `Enum` itself. + +```python exec="1" session="promotion-rules" result="python" +enum1 = nw.Enum(["orca", "narwhal"]) +enum2 = nw.Enum(["orca", "beluga"]) +print(f"{enum1 = !r}") +st(enum1, enum1) + +print(f"{enum2 = !r}") +st(enum1, enum2) +``` + +## Other non-nested types + +### Boolean + +A `Boolean` combined with any numeric type is promoted to that numeric type: + +```python exec="1" session="promotion-rules" result="python" +st(nw.Boolean(), nw.Int32()) +st(nw.Boolean(), nw.Float64()) +``` + +### Unknown + +If either operand is `Unknown`, the result is always `Unknown`. This is a fast-path that short-circuits all other logic. 
+ +## Nested Types + +Nested types (`Array`, `List`, `Struct`) generally require both operands to share the **same base type**. +The supertype is then determined by recursively resolving the inner types. + +An exception is the combination of `List` and `Array`, which can be promoted to `List` if they have the same depth. + +### List + +The supertype is a `List` whose inner type is the supertype of the two inner types. + +```python exec="1" session="promotion-rules" result="python" +left = nw.List(nw.UInt8()) +right = nw.List(nw.Int16()) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +If the inner types cannot be promoted, no supertype exists. + +```python exec="1" session="promotion-rules" result="python" +left = nw.List(nw.Int8()) +right = nw.List(nw.Categorical()) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +### Array + +Array supertyping follows the same rules as `List` supertyping, with one additional requirement: both arrays must have the **same shape**. +If they do, the supertype is an `Array` whose inner type is the supertype of the two inner types, +and the shape is the same as that of the operands: + +```python exec="1" session="promotion-rules" result="python" +left = nw.Array(nw.Float32(), shape=(2,)) +right = nw.Array(nw.Int32(), shape=(2,)) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) + +``` + +If the shapes are different, no supertype exists. + +```python exec="1" session="promotion-rules" result="python" +left = nw.Array(nw.Int8(), shape=(2, )) +right = nw.Array(nw.Int8(), shape=(2, 3)) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +### List and Array + +When combining a `List` with an `Array`, the supertype is a `List` if both have the **same depth** +(nesting level). The inner type is the supertype of the two inner types. 
+ +```python exec="1" session="promotion-rules" result="python" +left = nw.List(nw.Int32()) +right = nw.Array(nw.Int64(), shape=(2,)) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +This also works for nested structures with matching depth: + +```python exec="1" session="promotion-rules" result="python" +left = nw.List(nw.List(nw.Int8())) +right = nw.Array(nw.UInt8(), shape=(2, 3)) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +If the depths don't match, no supertype exists: + +```python exec="1" session="promotion-rules" result="python" +left = nw.List(nw.Int64()) +right = nw.Array(nw.Int32(), shape=(2, 3)) + +print(f"{left = !r}") +print(f"{right = !r}") +st(left, right) +``` + +### Struct + +Struct supertyping is more flexible and the **order of operands matters**: + +* If both structs have the **same number of fields** and **matching field names** (in order), the supertype is a + `Struct` where each field's dtype is the supertype of the corresponding fields. + + ```python exec="1" session="promotion-rules" result="python" + left = nw.Struct({"f0": nw.Int8(), "f1": nw.Int32()}) + right = nw.Struct({"f0": nw.UInt8(), "f1": nw.Int32()}) + + print(f"{left = !r}") + print(f"{right = !r}") + st(left, right) + ``` + +* If structs have positionally **mismatched field names**, a union is performed: + the result contains all fields from both structs. For fields that appear in both (by name), + their dtypes must have a valid supertype. + + ```python exec="1" session="promotion-rules" result="python" + left = nw.Struct({"f0": nw.Int8(), "f1": nw.Int32()}) + right = nw.Struct({"f0": nw.UInt8(), "f2": nw.Int64()}) + + print(f"{left = !r}") + print(f"{right = !r}") + st(left, right) + ``` + +* The `left` operand defines the field order of the output, *unless* `right` has more fields. 
+ + ```python exec="1" session="promotion-rules" result="python" + left = nw.Struct({"f1": nw.Int32(), "f0": nw.Int8()}) + right = nw.Struct({"f0": nw.UInt8(), "f2": nw.Int64(), "f1": nw.Int32()}) + + print(f"{left = !r}") + print(f"{right = !r}") + st(left, right) + ``` + +## No Supertype + +The following combinations have **no valid supertype** and will result in `None`: + +* `Struct` combined with non-nested types, except `String` and `Unknown` (note: like `List` and `Array`, a `Struct` combined with `String` is promoted to `String`) +* `Datetime` values with different time zones +* `Enum` values with different categories +* `Array` values with different shapes +* `List` and `Array` with different depths +* Temporal types combined with numeric types (unlike Polars, see [narwhals-dev/narwhals#121]) +* Any other unlisted combination of different base types + +## Everything in a single table + +{{ promotion_rules_table }} + +[get_supertype_with_options]: https://github.com/pola-rs/polars/blob/529f7ec642912a2f15656897d06f1532c2f5d4c4/crates/polars-core/src/utils/supertype.rs#L142-L543 +[base-type]: ../api-reference/dtypes.md#narwhals.dtypes.DType.base_type +[narwhals-dev/narwhals#121]: https://github.com/narwhals-dev/narwhals/issues/121 +[^1]: Given two data types `A` and `B`, their supertype `S` is the smallest (most specific) + type such that both `A` and `B` can be losslessly cast to it without risk of overflow, + truncation, or loss of precision. If no such type exists, the supertype is `None`. +[^2]: Types that are parameterized by additional metadata beyond their base classification. + Two values of the same parametric type may still be incompatible if their parameters differ. + For example, `Datetime` is parameterized by `time_unit` and `time_zone`; + `List` and `Array` by their inner element type; `Struct` by its field names and types; + and `Enum` by its set of categories.