diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index b433bd6fd5..d8e1958c40 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -175,6 +175,34 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override] def __hash__(self) -> int: return hash(self.__class__) + def __call__(self) -> Self: + # NOTE: (Internal doc) + # But, why? + # + # Let's say we have a function that looks like this: + # + # >>> import pyarrow as pa + # >>> from narwhals.dtypes import DType + # >>> from narwhals.typing import NonNestedDType + # + # >>> def convert(dtype: DType | type[NonNestedDType]) -> pa.DataType: ... + # + # + # If we need to resolve the union to a *class*, we can use: + # + # >>> always_class = dtype.base_type() + # + # But what if instead, we need an *instance*?: + # + # >>> always_instance_1 = dtype if isinstance(dtype, DType) else dtype() + # >>> always_instance_2 = dtype() if isinstance(dtype, type) else dtype + # + # By defining `__call__`, we can save an instance check and keep things readable: + # + # >>> always_instance = dtype() + # >>> always_class = dtype.base_type() + return self + class NumericType(DType): """Base class for numeric data types.""" diff --git a/tests/dtypes/dtypes_test.py b/tests/dtypes/dtypes_test.py index 7fb5cc29f2..ba0e9629df 100644 --- a/tests/dtypes/dtypes_test.py +++ b/tests/dtypes/dtypes_test.py @@ -8,13 +8,14 @@ import pytest import narwhals as nw +from narwhals.dtypes import DType from narwhals.exceptions import PerformanceWarning from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION, pyspark_session if TYPE_CHECKING: from collections.abc import Iterable - from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType + from narwhals.typing import IntoDType, IntoFrame, IntoSeries, NonNestedDType from tests.utils import Constructor, ConstructorPandasLike, NestedOrEnumDType @@ -585,6 +586,52 @@ def test_dtype_base_type_nested(nested_dtype: NestedOrEnumDType) -> None: assert base.base_type() == nested_dtype.base_type() +def test_dtype___call___(non_nested_type: type[NonNestedDType]) -> None: + dtype = non_nested_type() + assert dtype == non_nested_type()() + assert dtype() is dtype + + +def _resolve_base_type_dtype(into: IntoDType, /) -> tuple[type[DType], DType]: + return into.base_type(), into() + + +def test_dtype___call___typing( + non_nested_type: type[NonNestedDType], nested_dtype: NestedOrEnumDType +) -> None: + r1 = _resolve_base_type_dtype(non_nested_type) + r2 = _resolve_base_type_dtype(non_nested_type()) + r3 = _resolve_base_type_dtype(nested_dtype) + r4 = _resolve_base_type_dtype(nested_dtype()) + + pytest.importorskip("typing_extensions") + from typing_extensions import assert_type + + assert_type(r1, tuple[type[DType], DType]) + assert_type(r2, tuple[type[DType], DType]) + assert_type(r3, tuple[type[DType], DType]) + assert_type(r4, tuple[type[DType], DType]) + + dtype_meta = type(DType) + assert isinstance(r1[0], dtype_meta) + assert isinstance(r1[1], DType) + assert isinstance(r2[0], dtype_meta) + assert isinstance(r2[1], DType) + assert isinstance(r3[0], dtype_meta) + assert isinstance(r3[1], DType) + assert isinstance(r4[0], dtype_meta) + assert isinstance(r4[1], DType) + + assert r1 == r2 + assert r3 == r4 + assert r2 != r3 # just to be sure + + nested_type = type(nested_dtype) + + with pytest.raises(TypeError): + _resolve_base_type_dtype(nested_type) # type: ignore[arg-type] + + @pytest.mark.parametrize( ("dtype", "context"), [