diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index e336958277..f78a7ee496 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -4,18 +4,67 @@ from datetime import timezone from itertools import starmap from typing import TYPE_CHECKING +from typing import Any +from typing import Generic from typing import Mapping +from typing import cast +from typing import overload from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from typing import Iterator + from typing import Literal from typing import Sequence from typing_extensions import Self + from typing_extensions import TypeAlias + from typing_extensions import TypeIs + from typing_extensions import TypeVar from narwhals.typing import TimeUnit + _DTypeT = TypeVar("_DTypeT", bound="DType") + UnitT = TypeVar("UnitT", bound=TimeUnit) + _UnitT = TypeVar("_UnitT", bound=TimeUnit) + IntoZone: TypeAlias = "str | timezone | None" + ZoneT = TypeVar("ZoneT", str, None) + _ZoneT = TypeVar("_ZoneT", str, None) +else: + from typing import TypeVar + + UnitT = TypeVar("UnitT", bound="TimeUnit") + ZoneT = TypeVar("ZoneT", str, None) + +__all__ = [ + "Array", + "Boolean", + "Categorical", + "Date", + "Datetime", + "Decimal", + "Duration", + "Enum", + "Field", + "Float32", + "Float64", + "Int8", + "Int16", + "Int32", + "Int64", + "Int128", + "List", + "Object", + "String", + "Struct", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "UInt128", + "Unknown", +] + def _validate_dtype(dtype: DType | type[DType]) -> None: if not isinstance_or_issubclass(dtype, DType): @@ -62,7 +111,7 @@ def is_temporal(cls: type[Self]) -> bool: def is_nested(cls: type[Self]) -> bool: return issubclass(cls, NestedType) - def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override] + def __eq__(self: _DTypeT, other: object) -> TypeIs[_DTypeT | type[_DTypeT]]: # type: ignore[override] from narwhals.utils import isinstance_or_issubclass return isinstance_or_issubclass(other, type(self)) @@ -458,7 +507,7 @@ def time_zone(cls) -> str | None: return None -class Datetime(TemporalType, metaclass=_DatetimeMeta): +class Datetime(TemporalType, Generic[UnitT, ZoneT], metaclass=_DatetimeMeta): """Data type representing a calendar date and time of day. Arguments: @@ -499,10 +548,49 @@ class Datetime(TemporalType, metaclass=_DatetimeMeta): Datetime(time_unit='ms', time_zone='Africa/Accra') """ + time_unit: UnitT + time_zone: ZoneT + + @overload + def __init__( + self: Datetime[Literal["us"], None], + time_unit: Literal["us"] = ..., + time_zone: None = ..., + ) -> None: ... + + @overload + def __init__( + self: Datetime[_UnitT, None], time_unit: _UnitT, time_zone: None = ... + ) -> None: ... + + @overload + def __init__( + self: Datetime[_UnitT, _ZoneT], time_unit: _UnitT, time_zone: _ZoneT + ) -> None: ... + + @overload def __init__( - self: Self, - time_unit: TimeUnit = "us", - time_zone: str | timezone | None = None, + self: Datetime[_UnitT, str], time_unit: _UnitT, time_zone: timezone + ) -> None: ... + + @overload + def __init__( + self: Datetime[Literal["us"], _ZoneT], + time_unit: Literal["us"] = ..., + *, + time_zone: _ZoneT, + ) -> None: ... + + @overload + def __init__( + self: Datetime[Literal["us"], str], + time_unit: Literal["us"] = ..., + *, + time_zone: timezone, + ) -> None: ... + + def __init__( + self: Self, time_unit: TimeUnit | Literal["us"] = "us", time_zone: IntoZone = None ) -> None: if time_unit not in {"s", "ms", "us", "ns"}: msg = ( @@ -511,13 +599,26 @@ def __init__( ) raise ValueError(msg) - if isinstance(time_zone, timezone): - time_zone = str(time_zone) + zone = str(time_zone) if isinstance(time_zone, timezone) else time_zone + self.time_unit = cast("UnitT", time_unit) + self.time_zone = cast("ZoneT", zone) - self.time_unit: TimeUnit = time_unit - self.time_zone: str | None = time_zone + @overload # type: ignore[override] + def __eq__( + self: Datetime[UnitT, ZoneT], other: Datetime[UnitT, ZoneT] + ) -> TypeIs[Datetime[UnitT, ZoneT]]: ... + + @overload + def __eq__( # type: ignore[override] + self: Self, other: type[Datetime[Any, Any]] + ) -> TypeIs[type[Datetime[Any, Any]]]: ... + + @overload + def __eq__(self: Self, other: Datetime) -> TypeIs[Datetime]: ... - def __eq__(self: Self, other: object) -> bool: + def __eq__( # type: ignore[override] + self: Datetime[UnitT, ZoneT], other: object + ) -> TypeIs[Datetime[UnitT, ZoneT] | type[Datetime[Any, Any]] | Datetime]: # allow comparing object instances to class if type(other) is _DatetimeMeta: return True @@ -578,7 +679,8 @@ def __init__(self: Self, time_unit: TimeUnit = "us") -> None: self.time_unit: TimeUnit = time_unit - def __eq__(self: Self, other: object) -> bool: + # TODO @dangotbanned: convert to `TypeIs` + def __eq__(self: Self, other: object) -> bool: # type: ignore[override] # allow comparing object instances to class if type(other) is _DurationMeta: return True diff --git a/narwhals/stable/v1/_dtypes.py b/narwhals/stable/v1/_dtypes.py index f7813ec71d..9ea08cc803 100644 --- a/narwhals/stable/v1/_dtypes.py +++ b/narwhals/stable/v1/_dtypes.py @@ -1,6 +1,11 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any +from typing import Literal +from typing import TypeVar +from typing import cast +from typing import overload from narwhals.dtypes import Array from narwhals.dtypes import Boolean @@ -37,10 +42,20 @@ from narwhals.dtypes import UnsignedIntegerType if TYPE_CHECKING: + from datetime import timezone + from typing_extensions import Self + from narwhals.dtypes import IntoZone + from narwhals.dtypes import _UnitT + from narwhals.dtypes import _ZoneT + from narwhals.typing import TimeUnit + +UnitT = TypeVar("UnitT", bound="TimeUnit") +ZoneT = TypeVar("ZoneT", str, None) + -class Datetime(NwDatetime): +class Datetime(NwDatetime[UnitT, ZoneT]): """Data type representing a calendar date and time of day. Arguments: @@ -55,6 +70,49 @@ class Datetime(NwDatetime): def __hash__(self: Self) -> int: return hash(self.__class__) + @overload + def __init__( + self: Datetime[Literal["us"], None], + time_unit: Literal["us"] = ..., + time_zone: None = ..., + ) -> None: ... + + @overload + def __init__( + self: Datetime[_UnitT, None], time_unit: _UnitT, time_zone: None = ... + ) -> None: ... + + @overload + def __init__( + self: Datetime[_UnitT, _ZoneT], time_unit: _UnitT, time_zone: _ZoneT + ) -> None: ... + + @overload + def __init__( + self: Datetime[_UnitT, str], time_unit: _UnitT, time_zone: timezone + ) -> None: ... + + @overload + def __init__( + self: Datetime[Literal["us"], _ZoneT], + time_unit: Literal["us"] = ..., + *, + time_zone: _ZoneT, + ) -> None: ... + + @overload + def __init__( + self: Datetime[Literal["us"], str], + time_unit: Literal["us"] = ..., + *, + time_zone: timezone, + ) -> None: ... + + def __init__( + self: Self, time_unit: TimeUnit | Literal["us"] = "us", time_zone: IntoZone = None + ) -> None: + super().__init__(cast("Any", time_unit), cast("Any", time_zone)) + class Duration(NwDuration): """Data type representing a time duration. diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index ac69b0af77..6922dbc0e1 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -11,6 +11,7 @@ import polars as pl import pyarrow as pa import pytest +from typing_extensions import reveal_type import narwhals.stable.v1 as nw from tests.utils import PANDAS_VERSION @@ -41,7 +42,7 @@ def test_datetime_valid( @pytest.mark.parametrize("time_unit", ["abc"]) def test_datetime_invalid(time_unit: str) -> None: with pytest.raises(ValueError, match="invalid `time_unit`"): - nw.Datetime(time_unit=time_unit) # type: ignore[arg-type] + nw.Datetime(time_unit=time_unit) # type: ignore[call-overload] @pytest.mark.parametrize("time_unit", ["us", "ns", "ms"]) @@ -359,3 +360,46 @@ def test_cast_decimal_to_native() -> None: .with_columns(a=nw.col("a").cast(nw.Decimal())) .to_native() ) + + +def test_datetime_generic() -> None: + import narwhals as unstable_nw + + dt_1 = unstable_nw.Datetime() + dt_21 = unstable_nw.Datetime("ns") + dt_22 = unstable_nw.Datetime(time_unit="ns") + dt_3 = unstable_nw.Datetime("s", time_zone="zone") + dt_4 = unstable_nw.Datetime("ns", timezone.utc) + dt_5 = unstable_nw.Datetime(time_zone="Asia/Kathmandu") + dt_6 = unstable_nw.Datetime(time_zone=timezone.utc) + reveal_type(dt_1) + reveal_type(dt_21) + reveal_type(dt_22) + reveal_type(dt_3) + reveal_type(dt_4) + reveal_type(dt_5) + reveal_type(dt_6) + reveal_type(dt_3.time_unit) + assert dt_3.time_unit + + # ruff: noqa: F841 + + dtype = unstable_nw.Datetime("s") + bad = unstable_nw.Datetime("us", "USA") + + matches_2 = dtype == unstable_nw.Datetime + matches_1 = dtype == unstable_nw.Datetime("s", None) + matches_3 = dtype == bad + matches_none = dtype == unstable_nw.Duration + + if dtype == unstable_nw.Duration: + what = dtype + + if dtype != unstable_nw.Datetime: + what_again = dtype + + # NOTE: These **not** matching is a positive outcome + # - Omitting the overload is one way to enforce it + # - `Literal[False]` makes sense, but + if dtype == bad: + what3 = dtype diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 87f1b237ae..659fdbd7d4 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -215,7 +215,7 @@ # DTypes dtypes = [ - i for i in nw.dtypes.__dir__() if i[0].isupper() and not i.isupper() and i[0] != "_" + i for i in nw.dtypes.__all__ if i[0].isupper() and not i.isupper() and i[0] != "_" ] with open("docs/api-reference/dtypes.md") as fd: content = fd.read()