Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 113 additions & 11 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,67 @@
from datetime import timezone
from itertools import starmap
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import Mapping
from typing import cast
from typing import overload

from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from typing import Iterator
from typing import Literal
from typing import Sequence

from typing_extensions import Self
from typing_extensions import TypeAlias
from typing_extensions import TypeIs
from typing_extensions import TypeVar

from narwhals.typing import TimeUnit

_DTypeT = TypeVar("_DTypeT", bound="DType")
UnitT = TypeVar("UnitT", bound=TimeUnit)
_UnitT = TypeVar("_UnitT", bound=TimeUnit)
IntoZone: TypeAlias = "str | timezone | None"
ZoneT = TypeVar("ZoneT", str, None)
_ZoneT = TypeVar("_ZoneT", str, None)
else:
from typing import TypeVar

UnitT = TypeVar("UnitT", bound="TimeUnit")
ZoneT = TypeVar("ZoneT", str, None)

__all__ = [
"Array",
"Boolean",
"Categorical",
"Date",
"Datetime",
"Decimal",
"Duration",
"Enum",
"Field",
"Float32",
"Float64",
"Int8",
"Int16",
"Int32",
"Int64",
"Int128",
"List",
"Object",
"String",
"Struct",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
]


def _validate_dtype(dtype: DType | type[DType]) -> None:
if not isinstance_or_issubclass(dtype, DType):
Expand Down Expand Up @@ -62,7 +111,7 @@ def is_temporal(cls: type[Self]) -> bool:
def is_nested(cls: type[Self]) -> bool:
return issubclass(cls, NestedType)

def __eq__(self: Self, other: DType | type[DType]) -> bool: # type: ignore[override]
def __eq__(self: _DTypeT, other: object) -> TypeIs[_DTypeT | type[_DTypeT]]: # type: ignore[override]
from narwhals.utils import isinstance_or_issubclass

return isinstance_or_issubclass(other, type(self))
Expand Down Expand Up @@ -458,7 +507,7 @@ def time_zone(cls) -> str | None:
return None


class Datetime(TemporalType, metaclass=_DatetimeMeta):
class Datetime(TemporalType, Generic[UnitT, ZoneT], metaclass=_DatetimeMeta):
"""Data type representing a calendar date and time of day.

Arguments:
Expand Down Expand Up @@ -499,10 +548,49 @@ class Datetime(TemporalType, metaclass=_DatetimeMeta):
Datetime(time_unit='ms', time_zone='Africa/Accra')
"""

time_unit: UnitT
time_zone: ZoneT

@overload
def __init__(
self: Datetime[Literal["us"], None],
time_unit: Literal["us"] = ...,
time_zone: None = ...,
) -> None: ...

@overload
def __init__(
self: Datetime[_UnitT, None], time_unit: _UnitT, time_zone: None = ...
) -> None: ...

@overload
def __init__(
self: Datetime[_UnitT, _ZoneT], time_unit: _UnitT, time_zone: _ZoneT
) -> None: ...

@overload
def __init__(
self: Self,
time_unit: TimeUnit = "us",
time_zone: str | timezone | None = None,
self: Datetime[_UnitT, str], time_unit: _UnitT, time_zone: timezone
) -> None: ...

@overload
def __init__(
self: Datetime[Literal["us"], _ZoneT],
time_unit: Literal["us"] = ...,
*,
time_zone: _ZoneT,
) -> None: ...

@overload
def __init__(
self: Datetime[Literal["us"], str],
time_unit: Literal["us"] = ...,
*,
time_zone: timezone,
) -> None: ...

def __init__(
self: Self, time_unit: TimeUnit | Literal["us"] = "us", time_zone: IntoZone = None
) -> None:
if time_unit not in {"s", "ms", "us", "ns"}:
msg = (
Expand All @@ -511,13 +599,26 @@ def __init__(
)
raise ValueError(msg)

if isinstance(time_zone, timezone):
time_zone = str(time_zone)
zone = str(time_zone) if isinstance(time_zone, timezone) else time_zone
self.time_unit = cast("UnitT", time_unit)
self.time_zone = cast("ZoneT", zone)

self.time_unit: TimeUnit = time_unit
self.time_zone: str | None = time_zone
@overload # type: ignore[override]
def __eq__(
self: Datetime[UnitT, ZoneT], other: Datetime[UnitT, ZoneT]
) -> TypeIs[Datetime[UnitT, ZoneT]]: ...

@overload
def __eq__( # type: ignore[override]
self: Self, other: type[Datetime[Any, Any]]
) -> TypeIs[type[Datetime[Any, Any]]]: ...

@overload
def __eq__(self: Self, other: Datetime) -> TypeIs[Datetime]: ...

def __eq__(self: Self, other: object) -> bool:
def __eq__( # type: ignore[override]
self: Datetime[UnitT, ZoneT], other: object
) -> TypeIs[Datetime[UnitT, ZoneT] | type[Datetime[Any, Any]] | Datetime]:
# allow comparing object instances to class
if type(other) is _DatetimeMeta:
return True
Expand Down Expand Up @@ -578,7 +679,8 @@ def __init__(self: Self, time_unit: TimeUnit = "us") -> None:

self.time_unit: TimeUnit = time_unit

def __eq__(self: Self, other: object) -> bool:
# TODO @dangotbanned: convert to `TypeIs`
def __eq__(self: Self, other: object) -> bool: # type: ignore[override]
# allow comparing object instances to class
if type(other) is _DurationMeta:
return True
Expand Down
60 changes: 59 additions & 1 deletion narwhals/stable/v1/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import TypeVar
from typing import cast
from typing import overload

from narwhals.dtypes import Array
from narwhals.dtypes import Boolean
Expand Down Expand Up @@ -37,10 +42,20 @@
from narwhals.dtypes import UnsignedIntegerType

if TYPE_CHECKING:
from datetime import timezone

from typing_extensions import Self

from narwhals.dtypes import IntoZone
from narwhals.dtypes import _UnitT
from narwhals.dtypes import _ZoneT
from narwhals.typing import TimeUnit

UnitT = TypeVar("UnitT", bound="TimeUnit")
ZoneT = TypeVar("ZoneT", str, None)


class Datetime(NwDatetime):
class Datetime(NwDatetime[UnitT, ZoneT]):
"""Data type representing a calendar date and time of day.

Arguments:
Expand All @@ -55,6 +70,49 @@ class Datetime(NwDatetime):
def __hash__(self: Self) -> int:
return hash(self.__class__)

@overload
def __init__(
self: Datetime[Literal["us"], None],
time_unit: Literal["us"] = ...,
time_zone: None = ...,
) -> None: ...

@overload
def __init__(
self: Datetime[_UnitT, None], time_unit: _UnitT, time_zone: None = ...
) -> None: ...

@overload
def __init__(
self: Datetime[_UnitT, _ZoneT], time_unit: _UnitT, time_zone: _ZoneT
) -> None: ...

@overload
def __init__(
self: Datetime[_UnitT, str], time_unit: _UnitT, time_zone: timezone
) -> None: ...

@overload
def __init__(
self: Datetime[Literal["us"], _ZoneT],
time_unit: Literal["us"] = ...,
*,
time_zone: _ZoneT,
) -> None: ...

@overload
def __init__(
self: Datetime[Literal["us"], str],
time_unit: Literal["us"] = ...,
*,
time_zone: timezone,
) -> None: ...

def __init__(
self: Self, time_unit: TimeUnit | Literal["us"] = "us", time_zone: IntoZone = None
) -> None:
super().__init__(cast("Any", time_unit), cast("Any", time_zone))


class Duration(NwDuration):
"""Data type representing a time duration.
Expand Down
46 changes: 45 additions & 1 deletion tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import polars as pl
import pyarrow as pa
import pytest
from typing_extensions import reveal_type

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
Expand Down Expand Up @@ -41,7 +42,7 @@ def test_datetime_valid(
@pytest.mark.parametrize("time_unit", ["abc"])
def test_datetime_invalid(time_unit: str) -> None:
with pytest.raises(ValueError, match="invalid `time_unit`"):
nw.Datetime(time_unit=time_unit) # type: ignore[arg-type]
nw.Datetime(time_unit=time_unit) # type: ignore[call-overload]


@pytest.mark.parametrize("time_unit", ["us", "ns", "ms"])
Expand Down Expand Up @@ -359,3 +360,46 @@ def test_cast_decimal_to_native() -> None:
.with_columns(a=nw.col("a").cast(nw.Decimal()))
.to_native()
)


def test_datetime_generic() -> None:
import narwhals as unstable_nw

dt_1 = unstable_nw.Datetime()
dt_21 = unstable_nw.Datetime("ns")
dt_22 = unstable_nw.Datetime(time_unit="ns")
dt_3 = unstable_nw.Datetime("s", time_zone="zone")
dt_4 = unstable_nw.Datetime("ns", timezone.utc)
dt_5 = unstable_nw.Datetime(time_zone="Asia/Kathmandu")
dt_6 = unstable_nw.Datetime(time_zone=timezone.utc)
reveal_type(dt_1)
reveal_type(dt_21)
reveal_type(dt_22)
reveal_type(dt_3)
reveal_type(dt_4)
reveal_type(dt_5)
reveal_type(dt_6)
reveal_type(dt_3.time_unit)
assert dt_3.time_unit

# ruff: noqa: F841

dtype = unstable_nw.Datetime("s")
bad = unstable_nw.Datetime("us", "USA")

matches_2 = dtype == unstable_nw.Datetime
matches_1 = dtype == unstable_nw.Datetime("s", None)
matches_3 = dtype == bad
matches_none = dtype == unstable_nw.Duration

if dtype == unstable_nw.Duration:
what = dtype

if dtype != unstable_nw.Datetime:
what_again = dtype

# NOTE: These **not** matching is a positive outcome
# - Omitting the overload is one way to enforce it
# - `Literal[False]` makes sense, but
if dtype == bad:
what3 = dtype
2 changes: 1 addition & 1 deletion utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@

# DTypes
dtypes = [
i for i in nw.dtypes.__dir__() if i[0].isupper() and not i.isupper() and i[0] != "_"
i for i in nw.dtypes.__all__ if i[0].isupper() and not i.isupper() and i[0] != "_"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FIXME

Need some help getting this work again
https://results.pre-commit.ci/run/github/760058710/1739990897.pkejAgUjQ3CdKYXvgIDqPw

I don't understand what the two checks are doing https://narwhals-dev.github.io/narwhals/api-reference/dtypes/#narwhals.dtypes

Or why this seems to include Literal?

BASE_DTYPES = {
"NumericType",
"DType",
"TemporalType",
"Literal",
"OrderedDict",
"Mapping",
}

]
with open("docs/api-reference/dtypes.md") as fd:
content = fd.read()
Expand Down
Loading