Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,34 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
def __hash__(self) -> int:
return hash(self.__class__)

def __call__(self) -> Self:
# NOTE: (Internal doc)
# But, why?
#
# Let's say we have a function that looks like this:
#
# >>> import pyarrow as pa
# >>> from narwhals.dtypes import DType
# >>> from narwhals.typing import NonNestedDType
#
# >>> def convert(dtype: DType | type[NonNestedDType]) -> pa.DataType: ...
#
#
# If we need to resolve the union to a *class*, we can use:
#
# >>> always_class = dtype.base_type()
#
# But what if instead, we need an *instance*?:
#
# >>> always_instance_1 = dtype if isinstance(dtype, DType) else dtype()
# >>> always_instance_2 = dtype() if isinstance(dtype, type) else dtype
#
# By defining `__call__`, we can save an instance check and keep things readable:
#
# >>> always_instance = dtype()
# >>> always_class = dtype.base_type()
return self
Comment on lines +178 to +204
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm intentionally not making this public API, but want there to be a reminder on how it helps 🙂

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100% agree to avoid making this public 👌🏼



class NumericType(DType):
"""Base class for numeric data types."""
Expand Down
49 changes: 48 additions & 1 deletion tests/dtypes/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import pytest

import narwhals as nw
from narwhals.dtypes import DType
from narwhals.exceptions import PerformanceWarning
from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION, pyspark_session

if TYPE_CHECKING:
from collections.abc import Iterable

from narwhals.typing import IntoFrame, IntoSeries, NonNestedDType
from narwhals.typing import IntoDType, IntoFrame, IntoSeries, NonNestedDType
from tests.utils import Constructor, ConstructorPandasLike, NestedOrEnumDType


Expand Down Expand Up @@ -585,6 +586,52 @@ def test_dtype_base_type_nested(nested_dtype: NestedOrEnumDType) -> None:
assert base.base_type() == nested_dtype.base_type()


def test_dtype___call___(non_nested_type: type[NonNestedDType]) -> None:
dtype = non_nested_type()
assert dtype == non_nested_type()()
assert dtype() is dtype


def _resolve_base_type_dtype(into: IntoDType, /) -> tuple[type[DType], DType]:
return into.base_type(), into()
Comment on lines +595 to +596
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose this is a TL;DR

We can extract the class and instance from IntoDType - only paying the cost of initialization if we needed to anyway



def test_dtype___call___typing(
non_nested_type: type[NonNestedDType], nested_dtype: NestedOrEnumDType
) -> None:
r1 = _resolve_base_type_dtype(non_nested_type)
r2 = _resolve_base_type_dtype(non_nested_type())
r3 = _resolve_base_type_dtype(nested_dtype)
r4 = _resolve_base_type_dtype(nested_dtype())

pytest.importorskip("typing_extensions")
from typing_extensions import assert_type

assert_type(r1, tuple[type[DType], DType])
assert_type(r2, tuple[type[DType], DType])
assert_type(r3, tuple[type[DType], DType])
assert_type(r4, tuple[type[DType], DType])

dtype_meta = type(DType)
assert isinstance(r1[0], dtype_meta)
assert isinstance(r1[1], DType)
assert isinstance(r2[0], dtype_meta)
assert isinstance(r2[1], DType)
assert isinstance(r3[0], dtype_meta)
assert isinstance(r3[1], DType)
assert isinstance(r4[0], dtype_meta)
assert isinstance(r4[1], DType)

assert r1 == r2
assert r3 == r4
assert r2 != r3 # just to be sure

nested_type = type(nested_dtype)

with pytest.raises(TypeError):
_resolve_base_type_dtype(nested_type) # type: ignore[arg-type]


@pytest.mark.parametrize(
("dtype", "context"),
[
Expand Down
Loading