diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 29fc9d4d24..c7661d2a59 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import operator import re from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar @@ -44,6 +45,10 @@ from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray ExprT = TypeVar("ExprT", bound=PandasLikeExpr) + UnitCurrent: TypeAlias = TimeUnit + UnitTarget: TypeAlias = TimeUnit + BinOpBroadcast: TypeAlias = Callable[[Any, int], Any] + IntoRhs: TypeAlias = int PANDAS_LIKE_IMPLEMENTATION = { @@ -553,52 +558,47 @@ def int_dtype_mapper(dtype: Any) -> str: return "int64" -def calculate_timestamp_datetime( # noqa: C901, PLR0912 - s: NativeSeriesT, original_time_unit: str, time_unit: str +_TIMESTAMP_DATETIME_OP_FACTOR: Mapping[ + tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs] +] = { + ("ns", "us"): (operator.floordiv, 1_000), + ("ns", "ms"): (operator.floordiv, 1_000_000), + ("us", "ns"): (operator.mul, NS_PER_MICROSECOND), + ("us", "ms"): (operator.floordiv, 1_000), + ("ms", "ns"): (operator.mul, NS_PER_MILLISECOND), + ("ms", "us"): (operator.mul, 1_000), + ("s", "ns"): (operator.mul, NS_PER_SECOND), + ("s", "us"): (operator.mul, US_PER_SECOND), + ("s", "ms"): (operator.mul, MS_PER_SECOND), +} + + +def calculate_timestamp_datetime( + s: NativeSeriesT, current: TimeUnit, time_unit: TimeUnit ) -> NativeSeriesT: - if original_time_unit == "ns": - if time_unit == "ns": - result = s - elif time_unit == "us": - result = s // 1_000 - else: - result = s // 1_000_000 - elif original_time_unit == "us": - if time_unit == "ns": - result = s * NS_PER_MICROSECOND - elif time_unit == "us": - result = s - else: - result = s // 1_000 - elif original_time_unit == "ms": - if time_unit == "ns": - result = s * NS_PER_MILLISECOND - elif time_unit == "us": - result = s * 1_000 - else: - result = s - elif original_time_unit == "s": - if time_unit == "ns": - result = s * NS_PER_SECOND - elif time_unit == "us": - result = s * US_PER_SECOND - else: - result = s * MS_PER_SECOND + if current == time_unit: + return s + elif item := _TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)): + fn, factor = item + return fn(s, factor) else: # pragma: no cover - msg = f"unexpected time unit {original_time_unit}, please report a bug at https://github.com/narwhals-dev/narwhals" + msg = ( + f"unexpected time unit {current}, please report an issue at " + "https://github.com/narwhals-dev/narwhals" + ) raise AssertionError(msg) - return result - - -def calculate_timestamp_date(s: NativeSeriesT, time_unit: str) -> NativeSeriesT: - s = s * SECONDS_PER_DAY - if time_unit == "ns": - result = s * NS_PER_SECOND - elif time_unit == "us": - result = s * US_PER_SECOND - else: - result = s * MS_PER_SECOND - return result + + +_TIMESTAMP_DATE_FACTOR: Mapping[TimeUnit, int] = { + "ns": NS_PER_SECOND, + "us": US_PER_SECOND, + "ms": MS_PER_SECOND, + "s": 1, +} + + +def calculate_timestamp_date(s: NativeSeriesT, time_unit: TimeUnit) -> NativeSeriesT: + return s * SECONDS_PER_DAY * _TIMESTAMP_DATE_FACTOR[time_unit] def select_columns_by_name(