diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 8a9a9bfd1d..dd5ec8c284 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -185,6 +185,14 @@ def is_cudf_index(index: Any) -> TypeIs[cudf.Index]: ) # pragma: no cover +def is_cupy_scalar(obj: Any) -> bool: + return ( + (cupy := get_cupy()) is not None + and isinstance(obj, cupy.ndarray) + and obj.size == 1 + ) # pragma: no cover + + def is_dask_dataframe(df: Any) -> TypeIs[dd.DataFrame]: """Check whether `df` is a Dask DataFrame without importing Dask.""" return (dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame) @@ -227,6 +235,10 @@ def is_pyarrow_table(df: Any) -> TypeIs[pa.Table]: return (pa := get_pyarrow()) is not None and isinstance(df, pa.Table) +def is_pyarrow_scalar(obj: Any) -> TypeIs[pa.Scalar[Any]]: + return (pa := get_pyarrow()) is not None and isinstance(obj, pa.Scalar) + + def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]: """Check whether `df` is a PySpark DataFrame without importing PySpark.""" return bool( diff --git a/narwhals/translate.py b/narwhals/translate.py index 81790d1a0e..1837e954aa 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -1,7 +1,6 @@ from __future__ import annotations -from datetime import datetime -from datetime import timedelta +import datetime as dt from decimal import Decimal from functools import wraps from typing import TYPE_CHECKING @@ -16,7 +15,6 @@ from narwhals._namespace import is_native_polars from narwhals._namespace import is_native_spark_like from narwhals.dependencies import get_cudf -from narwhals.dependencies import get_cupy from narwhals.dependencies import get_dask from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_modin @@ -26,11 +24,13 @@ from narwhals.dependencies import get_pyarrow from narwhals.dependencies import is_cudf_dataframe from narwhals.dependencies import is_cudf_series +from narwhals.dependencies import is_cupy_scalar from narwhals.dependencies import is_dask_dataframe from narwhals.dependencies import is_duckdb_relation from narwhals.dependencies import is_ibis_table from narwhals.dependencies import is_modin_dataframe from narwhals.dependencies import is_modin_series +from narwhals.dependencies import is_numpy_scalar from narwhals.dependencies import is_pandas_dataframe from narwhals.dependencies import is_pandas_like_dataframe from narwhals.dependencies import is_pandas_series @@ -38,6 +38,7 @@ from narwhals.dependencies import is_polars_lazyframe from narwhals.dependencies import is_polars_series from narwhals.dependencies import is_pyarrow_chunked_array +from narwhals.dependencies import is_pyarrow_scalar from narwhals.dependencies import is_pyarrow_table from narwhals.utils import Version @@ -66,6 +67,7 @@ complex, Decimal, ) +TEMPORAL_SCALAR_TYPES = (dt.date, dt.timedelta, dt.time) @overload @@ -773,7 +775,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator(func) -def to_py_scalar(scalar_like: Any) -> Any: # noqa: C901, PLR0911, PLR0912 +def to_py_scalar(scalar_like: Any) -> Any: """If a scalar is not Python native, converts it to Python native. Arguments: @@ -798,56 +800,42 @@ def to_py_scalar(scalar_like: Any) -> Any: # noqa: C901, PLR0911, PLR0912 >>> nw.to_py_scalar(1) 1 """ - if scalar_like is None: - return None - if isinstance(scalar_like, NON_TEMPORAL_SCALAR_TYPES): - return scalar_like - - np = get_numpy() - if ( - np + scalar: Any + pd = get_pandas() + if scalar_like is None or isinstance(scalar_like, NON_TEMPORAL_SCALAR_TYPES): + scalar = scalar_like + elif ( + (np := get_numpy()) and isinstance(scalar_like, np.datetime64) and scalar_like.dtype == "datetime64[ns]" ): - return datetime(1970, 1, 1) + timedelta(microseconds=scalar_like.item() // 1000) - - if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"): - return scalar_like.item() - - pd = get_pandas() - if pd and isinstance(scalar_like, pd.Timestamp): - return scalar_like.to_pydatetime() - if pd and isinstance(scalar_like, pd.Timedelta): - return scalar_like.to_pytimedelta() - if pd and pd.api.types.is_scalar(scalar_like): - try: - is_na = pd.isna(scalar_like) - except Exception: # pragma: no cover # noqa: BLE001, S110 - pass - else: - if is_na: - return None - + ms = scalar_like.item() // 1000 + scalar = dt.datetime(1970, 1, 1) + dt.timedelta(microseconds=ms) + elif is_numpy_scalar(scalar_like) or is_cupy_scalar(scalar_like): + scalar = scalar_like.item() + elif pd and isinstance(scalar_like, pd.Timestamp): + scalar = scalar_like.to_pydatetime() + elif pd and isinstance(scalar_like, pd.Timedelta): + scalar = scalar_like.to_pytimedelta() # pd.Timestamp and pd.Timedelta subclass datetime and timedelta, # so we need to check this separately - if isinstance(scalar_like, (datetime, timedelta)): - return scalar_like - - pa = get_pyarrow() - if pa and isinstance(scalar_like, pa.Scalar): - return scalar_like.as_py() + elif isinstance(scalar_like, TEMPORAL_SCALAR_TYPES): + scalar = scalar_like + elif _is_pandas_na(scalar_like): + scalar = None + elif is_pyarrow_scalar(scalar_like): + scalar = scalar_like.as_py() + else: + msg = ( + f"Expected object convertible to a scalar, found {type(scalar_like)}.\n" + f"{scalar_like!r}" + ) + raise ValueError(msg) + return scalar - cupy = get_cupy() - if ( # pragma: no cover - cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1 - ): - return scalar_like.item() - msg = ( - f"Expected object convertible to a scalar, found {type(scalar_like)}. " - "Please report a bug to https://github.com/narwhals-dev/narwhals/issues" - ) - raise ValueError(msg) +def _is_pandas_na(obj: Any) -> bool: + return bool((pd := get_pandas()) and pd.api.types.is_scalar(obj) and pd.isna(obj)) __all__ = [ diff --git a/tests/translate/to_py_scalar_test.py b/tests/translate/to_py_scalar_test.py index d8086b839d..a8a92c817f 100644 --- a/tests/translate/to_py_scalar_test.py +++ b/tests/translate/to_py_scalar_test.py @@ -1,6 +1,8 @@ from __future__ import annotations +from datetime import date from datetime import datetime +from datetime import time from datetime import timedelta from decimal import Decimal from typing import Any @@ -29,6 +31,8 @@ (b"a", b"a"), (datetime(2021, 1, 1), datetime(2021, 1, 1)), (timedelta(days=1), timedelta(days=1)), + (date(1980, 1, 1), date(1980, 1, 1)), + (time(9, 45), time(9, 45)), (pd.Timestamp("2020-01-01"), datetime(2020, 1, 1)), (pd.Timedelta(days=3), timedelta(days=3)), (np.datetime64("2020-01-01", "s"), datetime(2020, 1, 1)),