Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ def is_cudf_index(index: Any) -> TypeIs[cudf.Index]:
) # pragma: no cover


def is_cupy_scalar(obj: Any) -> bool:  # pragma: no cover
    """Check whether `obj` is a single-element CuPy array.

    Returns `False` when `get_cupy()` yields `None`; otherwise `obj` must be
    a `cupy.ndarray` whose `size` equals 1.
    """
    cupy_module = get_cupy()
    if cupy_module is None:
        return False
    if not isinstance(obj, cupy_module.ndarray):
        return False
    # A one-element array is what callers treat as a "scalar" here.
    return obj.size == 1


def is_dask_dataframe(df: Any) -> TypeIs[dd.DataFrame]:
    """Check whether `df` is a Dask DataFrame without importing Dask.

    Returns `False` when `get_dask_dataframe()` yields `None`.
    """
    dask_dataframe = get_dask_dataframe()
    if dask_dataframe is None:
        return False
    return isinstance(df, dask_dataframe.DataFrame)
Expand Down Expand Up @@ -227,6 +235,10 @@ def is_pyarrow_table(df: Any) -> TypeIs[pa.Table]:
return (pa := get_pyarrow()) is not None and isinstance(df, pa.Table)


def is_pyarrow_scalar(obj: Any) -> TypeIs[pa.Scalar[Any]]:
    """Check whether `obj` is an instance of `pa.Scalar`.

    Returns `False` when `get_pyarrow()` yields `None`.
    """
    pyarrow = get_pyarrow()
    if pyarrow is None:
        return False
    return isinstance(obj, pyarrow.Scalar)


def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
"""Check whether `df` is a PySpark DataFrame without importing PySpark."""
return bool(
Expand Down
82 changes: 35 additions & 47 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from datetime import datetime
from datetime import timedelta
import datetime as dt
from decimal import Decimal
from functools import wraps
from typing import TYPE_CHECKING
Expand All @@ -16,7 +15,6 @@
from narwhals._namespace import is_native_polars
from narwhals._namespace import is_native_spark_like
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_cupy
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_modin
Expand All @@ -26,18 +24,21 @@
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import is_cudf_dataframe
from narwhals.dependencies import is_cudf_series
from narwhals.dependencies import is_cupy_scalar
from narwhals.dependencies import is_dask_dataframe
from narwhals.dependencies import is_duckdb_relation
from narwhals.dependencies import is_ibis_table
from narwhals.dependencies import is_modin_dataframe
from narwhals.dependencies import is_modin_series
from narwhals.dependencies import is_numpy_scalar
from narwhals.dependencies import is_pandas_dataframe
from narwhals.dependencies import is_pandas_like_dataframe
from narwhals.dependencies import is_pandas_series
from narwhals.dependencies import is_polars_dataframe
from narwhals.dependencies import is_polars_lazyframe
from narwhals.dependencies import is_polars_series
from narwhals.dependencies import is_pyarrow_chunked_array
from narwhals.dependencies import is_pyarrow_scalar
from narwhals.dependencies import is_pyarrow_table
from narwhals.utils import Version

Expand Down Expand Up @@ -66,6 +67,7 @@
complex,
Decimal,
)
TEMPORAL_SCALAR_TYPES = (dt.date, dt.timedelta, dt.time)


@overload
Expand Down Expand Up @@ -773,7 +775,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return decorator(func)


def to_py_scalar(scalar_like: Any) -> Any: # noqa: C901, PLR0911, PLR0912
def to_py_scalar(scalar_like: Any) -> Any:
"""If a scalar is not Python native, converts it to Python native.

Arguments:
Expand All @@ -798,56 +800,42 @@ def to_py_scalar(scalar_like: Any) -> Any: # noqa: C901, PLR0911, PLR0912
>>> nw.to_py_scalar(1)
1
"""
if scalar_like is None:
return None
if isinstance(scalar_like, NON_TEMPORAL_SCALAR_TYPES):
return scalar_like

np = get_numpy()
if (
np
scalar: Any
pd = get_pandas()
if scalar_like is None or isinstance(scalar_like, NON_TEMPORAL_SCALAR_TYPES):
scalar = scalar_like
elif (
(np := get_numpy())
and isinstance(scalar_like, np.datetime64)
and scalar_like.dtype == "datetime64[ns]"
):
return datetime(1970, 1, 1) + timedelta(microseconds=scalar_like.item() // 1000)

if np and np.isscalar(scalar_like) and hasattr(scalar_like, "item"):
return scalar_like.item()

pd = get_pandas()
if pd and isinstance(scalar_like, pd.Timestamp):
return scalar_like.to_pydatetime()
if pd and isinstance(scalar_like, pd.Timedelta):
return scalar_like.to_pytimedelta()
if pd and pd.api.types.is_scalar(scalar_like):
try:
is_na = pd.isna(scalar_like)
except Exception: # pragma: no cover # noqa: BLE001, S110
pass
else:
if is_na:
return None

ms = scalar_like.item() // 1000
scalar = dt.datetime(1970, 1, 1) + dt.timedelta(microseconds=ms)
elif is_numpy_scalar(scalar_like) or is_cupy_scalar(scalar_like):
scalar = scalar_like.item()
elif pd and isinstance(scalar_like, pd.Timestamp):
scalar = scalar_like.to_pydatetime()
elif pd and isinstance(scalar_like, pd.Timedelta):
scalar = scalar_like.to_pytimedelta()
# pd.Timestamp and pd.Timedelta subclass datetime and timedelta,
# so we need to check this separately
if isinstance(scalar_like, (datetime, timedelta)):
return scalar_like

pa = get_pyarrow()
if pa and isinstance(scalar_like, pa.Scalar):
return scalar_like.as_py()
elif isinstance(scalar_like, TEMPORAL_SCALAR_TYPES):
scalar = scalar_like
elif _is_pandas_na(scalar_like):
scalar = None
elif is_pyarrow_scalar(scalar_like):
scalar = scalar_like.as_py()
else:
msg = (
f"Expected object convertible to a scalar, found {type(scalar_like)}.\n"
f"{scalar_like!r}"
)
raise ValueError(msg)
return scalar

cupy = get_cupy()
if ( # pragma: no cover
cupy and isinstance(scalar_like, cupy.ndarray) and scalar_like.size == 1
):
return scalar_like.item()

msg = (
f"Expected object convertible to a scalar, found {type(scalar_like)}. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise ValueError(msg)
def _is_pandas_na(obj: Any) -> bool:
    """Return whether `obj` is a scalar that pandas reports as missing.

    Evaluates to `False` when `get_pandas()` returns a falsy value or when
    `pd.api.types.is_scalar(obj)` is false; otherwise the result of
    `pd.isna(obj)` coerced to `bool`.
    """
    pandas_module = get_pandas()
    if not pandas_module:
        return False
    # Only scalars are passed to `isna`; non-scalars are rejected first.
    if not pandas_module.api.types.is_scalar(obj):
        return False
    return bool(pandas_module.isna(obj))


__all__ = [
Expand Down
4 changes: 4 additions & 0 deletions tests/translate/to_py_scalar_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from datetime import date
from datetime import datetime
from datetime import time
from datetime import timedelta
from decimal import Decimal
from typing import Any
Expand Down Expand Up @@ -29,6 +31,8 @@
(b"a", b"a"),
(datetime(2021, 1, 1), datetime(2021, 1, 1)),
(timedelta(days=1), timedelta(days=1)),
(date(1980, 1, 1), date(1980, 1, 1)),
(time(9, 45), time(9, 45)),
(pd.Timestamp("2020-01-01"), datetime(2020, 1, 1)),
(pd.Timedelta(days=3), timedelta(days=3)),
(np.datetime64("2020-01-01", "s"), datetime(2020, 1, 1)),
Expand Down
Loading