diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index fad917a206..ef32c2e068 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -29,7 +29,6 @@ from narwhals._arrow.typing import Incomplete from narwhals._arrow.typing import StringArray from narwhals.dtypes import DType - from narwhals.typing import TimeUnit from narwhals.typing import _AnyDArray from narwhals.utils import Version @@ -182,12 +181,9 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa if isinstance_or_issubclass(dtype, dtypes.Categorical): return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): - time_unit: TimeUnit = getattr(dtype, "time_unit", "us") - time_zone = getattr(dtype, "time_zone", None) - return pa.timestamp(time_unit, tz=time_zone) + return pa.timestamp(dtype.time_unit, tz=dtype.time_zone) # type: ignore[arg-type] if isinstance_or_issubclass(dtype, dtypes.Duration): - time_unit = getattr(dtype, "time_unit", "us") - return pa.duration(time_unit) + return pa.duration(dtype.time_unit) if isinstance_or_issubclass(dtype, dtypes.Date): return pa.date32() if isinstance_or_issubclass(dtype, dtypes.List): diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 324769ca88..1216e341eb 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -176,12 +176,12 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st msg = "Categorical not supported by DuckDB" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Datetime): - _time_unit = getattr(dtype, "time_unit", "us") - _time_zone = getattr(dtype, "time_zone", None) + _time_unit = dtype.time_unit + _time_zone = dtype.time_zone msg = "todo" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Duration): # pragma: no cover - _time_unit = getattr(dtype, "time_unit", "us") + _time_unit = dtype.time_unit msg = "todo" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Date): # pragma: no cover diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 98f613e9fa..8bb4317934 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -613,28 +613,27 @@ def narwhals_to_native_dtype( # noqa: PLR0915 # convert to it? return "category" if isinstance_or_issubclass(dtype, dtypes.Datetime): - dt_time_unit = getattr(dtype, "time_unit", "us") - dt_time_zone = getattr(dtype, "time_zone", None) - # Pandas does not support "ms" or "us" time units before version 2.0 - # Let's overwrite with "ns" if implementation is Implementation.PANDAS and backend_version < ( 2, ): # pragma: no cover dt_time_unit = "ns" + else: + dt_time_unit = dtype.time_unit if dtype_backend == "pyarrow": - tz_part = f", tz={dt_time_zone}" if dt_time_zone else "" + tz_part = f", tz={tz}" if (tz := dtype.time_zone) else "" return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]" else: - tz_part = f", {dt_time_zone}" if dt_time_zone else "" + tz_part = f", {tz}" if (tz := dtype.time_zone) else "" return f"datetime64[{dt_time_unit}{tz_part}]" if isinstance_or_issubclass(dtype, dtypes.Duration): - du_time_unit = getattr(dtype, "time_unit", "us") if implementation is Implementation.PANDAS and backend_version < ( 2, ): # pragma: no cover - dt_time_unit = "ns" + du_time_unit = "ns" + else: + du_time_unit = dtype.time_unit return ( f"duration[{du_time_unit}][pyarrow]" if dtype_backend == "pyarrow" diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index eaae9152b0..79491760c5 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -14,6 +14,7 @@ from narwhals.exceptions import NarwhalsError from narwhals.exceptions import ShapeError from narwhals.utils import import_dtypes_module +from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from narwhals._polars.dataframe import PolarsDataFrame @@ -190,13 +191,10 @@ def narwhals_to_native_dtype( if dtype == dtypes.Decimal: msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) - if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime): - dt_time_unit: TimeUnit = getattr(dtype, "time_unit", "us") - dt_time_zone = getattr(dtype, "time_zone", None) - return pl.Datetime(dt_time_unit, dt_time_zone) # type: ignore[arg-type] - if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration): - du_time_unit: TimeUnit = getattr(dtype, "time_unit", "us") - return pl.Duration(time_unit=du_time_unit) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.Datetime): + return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type] + if isinstance_or_issubclass(dtype, dtypes.Duration): + return pl.Duration(dtype.time_unit) # type: ignore[arg-type] if dtype == dtypes.List: return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) # type: ignore[union-attr] if dtype == dtypes.Struct: diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 1176164d67..46ee0fd877 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -116,7 +116,7 @@ def narwhals_to_native_dtype( if isinstance_or_issubclass(dtype, dtypes.Date): return spark_types.DateType() if isinstance_or_issubclass(dtype, dtypes.Datetime): - dt_time_zone = getattr(dtype, "time_zone", None) + dt_time_zone = dtype.time_zone if dt_time_zone is None: return spark_types.TimestampNTZType() if dt_time_zone != "UTC": # pragma: no cover diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 2f408b7a35..e336958277 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -448,7 +448,17 @@ class Unknown(DType): """ -class Datetime(TemporalType): +class _DatetimeMeta(type): + @property + def time_unit(cls) -> TimeUnit: + return "us" + + @property + def time_zone(cls) -> str | None: + return None + + +class Datetime(TemporalType, metaclass=_DatetimeMeta): """Data type representing a calendar date and time of day. Arguments: @@ -505,11 +515,11 @@ def __init__( time_zone = str(time_zone) self.time_unit: TimeUnit = time_unit - self.time_zone = time_zone + self.time_zone: str | None = time_zone def __eq__(self: Self, other: object) -> bool: # allow comparing object instances to class - if type(other) is type and issubclass(other, self.__class__): + if type(other) is _DatetimeMeta: return True elif isinstance(other, self.__class__): return self.time_unit == other.time_unit and self.time_zone == other.time_zone @@ -524,7 +534,13 @@ def __repr__(self: Self) -> str: # pragma: no cover return f"{class_name}(time_unit={self.time_unit!r}, time_zone={self.time_zone!r})" -class Duration(TemporalType): +class _DurationMeta(type): + @property + def time_unit(cls) -> TimeUnit: + return "us" + + +class Duration(TemporalType, metaclass=_DurationMeta): """Data type representing a time duration. Arguments: @@ -552,10 +568,7 @@ class Duration(TemporalType): Duration(time_unit='ms') """ - def __init__( - self: Self, - time_unit: TimeUnit = "us", - ) -> None: + def __init__(self: Self, time_unit: TimeUnit = "us") -> None: if time_unit not in ("s", "ms", "us", "ns"): msg = ( "invalid `time_unit`" @@ -563,11 +576,11 @@ def __init__( ) raise ValueError(msg) - self.time_unit = time_unit + self.time_unit: TimeUnit = time_unit def __eq__(self: Self, other: object) -> bool: # allow comparing object instances to class - if type(other) is type and issubclass(other, self.__class__): + if type(other) is _DurationMeta: return True elif isinstance(other, self.__class__): return self.time_unit == other.time_unit