diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 650217bd3d..ef373e5df1 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -28,12 +28,14 @@ from narwhals._typing_compat import assert_never from narwhals._utils import ( Implementation, + Version, generate_temporary_column_name, is_list_of, no_default, not_implemented, ) from narwhals.dependencies import is_numpy_array_1d +from narwhals.dtypes import _validate_cast_temporal_to_numeric from narwhals.exceptions import InvalidOperationError, ShapeError if TYPE_CHECKING: @@ -61,7 +63,7 @@ ) from narwhals._compliant.series import HistData from narwhals._typing import NoDefault - from narwhals._utils import Version, _LimitedContext + from narwhals._utils import _LimitedContext from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, @@ -543,6 +545,7 @@ def is_nan(self) -> Self: return self._with_native(pc.is_nan(self.native), preserve_broadcast=True) def cast(self, dtype: IntoDType) -> Self: + _validate_cast_temporal_to_numeric(source=self.dtype, target=dtype) data_type = narwhals_to_native_dtype(dtype, self._version) return self._with_native(pc.cast(self.native, data_type), preserve_broadcast=True) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 75a3cbb964..09f90fd71b 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -19,10 +19,12 @@ from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype from narwhals._utils import ( Implementation, + Version, generate_temporary_column_name, no_default, not_implemented, ) +from narwhals.dtypes import _validate_cast_temporal_to_numeric from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: @@ -40,7 +42,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace from narwhals._typing import NoDefault - from narwhals._utils import Version, _LimitedContext + from narwhals._utils import 
_LimitedContext from narwhals.typing import ( FillNullStrategy, IntoDType, @@ -613,11 +615,21 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]: ) def cast(self, dtype: IntoDType) -> Self: - def func(expr: dx.Series) -> dx.Series: + def func(df: DaskLazyFrame) -> list[dx.Series]: + if dtype.is_numeric(): + schema = df.schema + for name in self._evaluate_output_names(df): + _validate_cast_temporal_to_numeric(source=schema[name], target=dtype) + native_dtype = narwhals_to_native_dtype(dtype, self._version) - return expr.astype(native_dtype) + return [expr.astype(native_dtype) for expr in self._call(df)] - return self._with_callable(func) + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + version=self._version, + ) def is_finite(self) -> Self: import dask.array as da diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 7a0977a54c..ad1cb7a8d6 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -22,6 +22,7 @@ ) from narwhals._sql.expr import SQLExpr from narwhals._utils import Implementation, Version, extend_bool, no_default +from narwhals.dtypes import _validate_cast_temporal_to_numeric if TYPE_CHECKING: from collections.abc import Sequence @@ -38,6 +39,7 @@ ) from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals._duckdb.namespace import DuckDBNamespace + from narwhals._duckdb.utils import duckdb_dtypes from narwhals._typing import NoDefault from narwhals._utils import _LimitedContext from narwhals.typing import FillNullStrategy, IntoDType, RollingInterpolationMethod @@ -266,14 +268,23 @@ def _fill_constant(expr: Expression, value: Any) -> Expression: return self._with_elementwise(_fill_constant, value=value) def cast(self, dtype: IntoDType) -> Self: - def func(df: DuckDBLazyFrame) -> list[Expression]: + def _validated_dtype( + dtype: IntoDType, df: DuckDBLazyFrame + ) -> duckdb_dtypes.DuckDBPyType: + if dtype.is_numeric(): + 
schema = df.collect_schema() + for name in self._evaluate_output_names(df): + _validate_cast_temporal_to_numeric(source=schema[name], target=dtype) + tz = DeferredTimeZone(df.native) - native_dtype = narwhals_to_native_dtype(dtype, self._version, tz) + return narwhals_to_native_dtype(dtype, self._version, tz) + + def func(df: DuckDBLazyFrame) -> list[Expression]: + native_dtype = _validated_dtype(dtype, df) return [expr.cast(native_dtype) for expr in self(df)] def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]: - tz = DeferredTimeZone(df.native) - native_dtype = narwhals_to_native_dtype(dtype, self._version, tz) + native_dtype = _validated_dtype(dtype, df) return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)] return self.__class__( diff --git a/narwhals/_ibis/expr.py b/narwhals/_ibis/expr.py index 9de16232bb..0abc40e372 100644 --- a/narwhals/_ibis/expr.py +++ b/narwhals/_ibis/expr.py @@ -28,6 +28,7 @@ not_implemented, zip_strict, ) +from narwhals.dtypes import _validate_cast_temporal_to_numeric if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -269,12 +270,21 @@ def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value: return self._with_callable(_fill_null, value=value) def cast(self, dtype: IntoDType) -> Self: - def _func(expr: ir.Column) -> ir.Value: + def func(df: IbisLazyFrame) -> list[ir.Value]: + if dtype.is_numeric(): + schema = df.collect_schema() + for name in self._evaluate_output_names(df): + _validate_cast_temporal_to_numeric(source=schema[name], target=dtype) + native_dtype = narwhals_to_native_dtype(dtype, self._version) - # ibis `cast` overloads do not include DataType, only literals - return expr.cast(native_dtype) # type: ignore[unused-ignore] + return [expr.cast(native_dtype) for expr in self(df)] # pyright: ignore[reportArgumentType, reportCallIssue] - return self._with_callable(_func) + return self.__class__( + func, + evaluate_output_names=self._evaluate_output_names, 
+ alias_output_names=self._alias_output_names, + version=self._version, + ) def is_unique(self) -> Self: return self._with_callable( diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index f0dc6c4e48..c67132297c 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -24,8 +24,9 @@ set_index, ) from narwhals._typing_compat import assert_never -from narwhals._utils import Implementation, is_list_of, no_default +from narwhals._utils import Implementation, Version, is_list_of, no_default from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series +from narwhals.dtypes import _validate_cast_temporal_to_numeric from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: @@ -43,7 +44,7 @@ from narwhals._pandas_like.namespace import PandasLikeNamespace from narwhals._pandas_like.typing import NativeSeriesT from narwhals._typing import NoDefault - from narwhals._utils import Version, _LimitedContext + from narwhals._utils import _LimitedContext from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, @@ -311,6 +312,7 @@ def scatter( return None if in_place else self._with_native(series) def cast(self, dtype: IntoDType) -> Self: + _validate_cast_temporal_to_numeric(source=self.dtype, target=dtype) if self.dtype == dtype and self.native.dtype != "object": # Avoid dealing with pandas' type-system if we can. 
Note that it's only # safe to do this if we're not starting with object dtype, see tests/expr_and_series/cast_test.py::test_cast_object_pandas diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index b77bba4eb9..a5bf4a517c 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -20,8 +20,9 @@ narwhals_to_native_dtype, native_to_narwhals_dtype, ) -from narwhals._utils import Implementation, no_default, requires +from narwhals._utils import Implementation, Version, no_default, requires from narwhals.dependencies import is_numpy_array_1d, is_pandas_index +from narwhals.dtypes import _validate_cast_temporal_to_numeric if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence @@ -35,7 +36,7 @@ from narwhals._polars.dataframe import Method, PolarsDataFrame from narwhals._polars.namespace import PolarsNamespace from narwhals._typing import NoDefault - from narwhals._utils import Version, _LimitedContext + from narwhals._utils import _LimitedContext from narwhals.dtypes import DType from narwhals.series import Series from narwhals.typing import ( @@ -288,7 +289,8 @@ def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self: return self._from_native_object(self.native.__getitem__(item)) def cast(self, dtype: IntoDType) -> Self: - dtype_pl = narwhals_to_native_dtype(dtype, self._version) + _validate_cast_temporal_to_numeric(source=self.dtype, target=dtype) + dtype_pl = narwhals_to_native_dtype(dtype, version=self._version) return self._with_native(self.native.cast(dtype_pl)) def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self: diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 4fc0883e66..9f3fbeaa77 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast from 
narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace @@ -23,6 +24,7 @@ not_implemented, zip_strict, ) +from narwhals.dtypes import DType, _validate_cast_temporal_to_numeric if TYPE_CHECKING: from collections.abc import Iterator, Mapping, Sequence @@ -40,6 +42,7 @@ ) from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals._spark_like.utils import _NativeDType from narwhals._typing import NoDefault from narwhals._utils import _LimitedContext from narwhals.typing import FillNullStrategy, IntoDType, RankMethod @@ -246,19 +249,31 @@ def __invert__(self) -> Self: return self._with_elementwise(invert) def cast(self, dtype: IntoDType) -> Self: - def func(df: SparkLikeLazyFrame) -> Sequence[Column]: - spark_dtype = narwhals_to_native_dtype( + def _validated_dtype(dtype: IntoDType, df: SparkLikeLazyFrame) -> _NativeDType: + if dtype.is_numeric(): + schema: dict[str, DType] = {} + with suppress(Exception): + schema = df.collect_schema() + + if schema: + for name in self._evaluate_output_names(df): + _validate_cast_temporal_to_numeric( + source=schema[name], target=dtype + ) + + return narwhals_to_native_dtype( dtype, self._version, self._native_dtypes, df.native.sparkSession ) - return [expr.cast(spark_dtype) for expr in self(df)] + + def func(df: SparkLikeLazyFrame) -> Sequence[Column]: + native_dtype = _validated_dtype(dtype, df) + return [expr.cast(native_dtype) for expr in self(df)] def window_f( df: SparkLikeLazyFrame, inputs: SparkWindowInputs ) -> Sequence[Column]: - spark_dtype = narwhals_to_native_dtype( - dtype, self._version, self._native_dtypes, df.native.sparkSession - ) - return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)] + native_dtype = _validated_dtype(dtype, df) + return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)] return self.__class__( func, diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 
587b6c2758..e74f8803d9 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -59,6 +59,29 @@ def _validate_into_dtype(dtype: Any) -> None: raise TypeError(msg) +def _validate_cast_temporal_to_numeric( + source: DType | type[DType], target: IntoDType +) -> None: + """Validate that we're not casting from temporal to numeric types. + + Arguments: + source: The source data type. + target: The target data type to cast to. + + Raises: + InvalidOperationError: If attempting to cast from temporal to numeric. + """ + if source.is_temporal() and target.is_numeric(): + msg = ( + "Casting from temporal type to numeric is not supported.\n\n" + "Hint: Use `.dt` accessor methods instead, such as:\n" + " - `.dt.timestamp()` for Unix timestamp.\n" + " - `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.\n" + " - `.dt.total_seconds()`, `.dt.total_milliseconds(), ..., for duration total time." + ) + raise InvalidOperationError(msg) + + class DTypeClass(type): """Metaclass for DType classes. diff --git a/narwhals/expr.py b/narwhals/expr.py index 9b6f616f71..3d16bea020 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -171,6 +171,15 @@ def cast(self, dtype: IntoDType) -> Self: Arguments: dtype: Data type that the object will be cast into. + Note: + Unlike Polars, we don't allow casting from a temporal to a numeric data type. + + Use `.dt` accessor methods instead, such as: + + * `.dt.timestamp()` for Unix timestamp. + * `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components. + * `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time. + Examples: >>> import pandas as pd >>> import narwhals as nw diff --git a/narwhals/series.py b/narwhals/series.py index 6acc476245..a21e980b7b 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -628,6 +628,15 @@ def cast(self, dtype: IntoDType) -> Self: Arguments: dtype: Data type that the object will be cast into. 
+ Note: + Unlike Polars, we don't allow casting from a temporal to a numeric data type. + + Use `.dt` accessor methods instead, such as: + + * `.dt.timestamp()` for Unix timestamp. + * `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components. + * `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time. + Examples: >>> import pyarrow as pa >>> import narwhals as nw diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index cfadaff347..32acb32381 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -6,6 +6,7 @@ import pytest import narwhals as nw +from narwhals.exceptions import InvalidOperationError from tests.utils import ( PANDAS_VERSION, PYARROW_VERSION, @@ -442,3 +443,59 @@ def test_cast_object_pandas() -> None: s = nw.from_native(pd.DataFrame({"a": [2, 3, None]}, dtype=object))["a"] assert s[0] == 2 assert s.cast(nw.String)[0] == "2" + + +NUMERIC_DTYPES = [ + nw.Int8, + nw.Int16, + nw.Int32, + nw.Int64, + nw.Float32, + nw.Float64, + nw.UInt32, + nw.UInt64, +] + + +@pytest.mark.parametrize( + "values", [[datetime(2000, 1, 1, 12, 0), None], [timedelta(365, 59), None]] +) +@pytest.mark.parametrize(("target_dtype"), NUMERIC_DTYPES) +def test_cast_temporal_to_numeric_raises_expr( + constructor: Constructor, + request: pytest.FixtureRequest, + values: list[datetime] | list[timedelta], + target_dtype: nw.dtypes.DType, +) -> None: + if "polars" in str(constructor): + reason = "Polars expressions wrap native expressions" + request.applymarker(pytest.mark.xfail(reason=reason)) + + if isinstance(values[0], timedelta) and "spark" in str(constructor): + reason = "interval not implemented" + request.applymarker(pytest.mark.xfail(reason=reason)) + + df = nw.from_native(constructor({"a": values})).lazy() + msg = "Casting from temporal type to numeric" + with pytest.raises(InvalidOperationError, match=msg): + df.select(nw.col("a").cast(target_dtype)).collect() + + 
+@pytest.mark.parametrize( + "values", + [ + [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), None], + [timedelta(2, 59), timedelta(1, 59), None], + ], +) +@pytest.mark.parametrize(("target_dtype"), NUMERIC_DTYPES) +def test_cast_temporal_to_numeric_raises_series( + constructor_eager: ConstructorEager, + values: list[datetime] | list[timedelta], + target_dtype: nw.dtypes.DType, +) -> None: + df = nw.from_native(constructor_eager({"a": values}), eager_only=True) + series = df["a"] + msg = "Casting from temporal type to numeric" + with pytest.raises(InvalidOperationError, match=msg): + series.cast(target_dtype)