Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
is_list_of,
no_default,
not_implemented,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError, ShapeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +63,7 @@
)
from narwhals._compliant.series import HistData
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -543,6 +545,7 @@ def is_nan(self) -> Self:
return self._with_native(pc.is_nan(self.native), preserve_broadcast=True)

def cast(self, dtype: IntoDType) -> Self:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
data_type = narwhals_to_native_dtype(dtype, self._version)
return self._with_native(pc.cast(self.native, data_type), preserve_broadcast=True)

Expand Down
20 changes: 16 additions & 4 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
no_default,
not_implemented,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -40,7 +42,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
Expand Down Expand Up @@ -613,11 +615,21 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
)

def cast(self, dtype: IntoDType) -> Self:
def func(expr: dx.Series) -> dx.Series:
def func(df: DaskLazyFrame) -> list[dx.Series]:
if dtype.is_numeric():
schema = df.schema
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
return [expr.astype(native_dtype) for expr in self._call(df)]

return self._with_callable(func)
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

def is_finite(self) -> Self:
import dask.array as da
Expand Down
19 changes: 15 additions & 4 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, extend_bool, no_default
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -38,6 +39,7 @@
)
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.utils import duckdb_dtypes
from narwhals._typing import NoDefault
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, RollingInterpolationMethod
Expand Down Expand Up @@ -266,14 +268,23 @@ def _fill_constant(expr: Expression, value: Any) -> Expression:
return self._with_elementwise(_fill_constant, value=value)

def cast(self, dtype: IntoDType) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
def _validated_dtype(
dtype: IntoDType, df: DuckDBLazyFrame
) -> duckdb_dtypes.DuckDBPyType:
if dtype.is_numeric():
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return narwhals_to_native_dtype(dtype, self._version, tz)

def func(df: DuckDBLazyFrame) -> list[Expression]:
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
Expand Down
18 changes: 14 additions & 4 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -269,12 +270,21 @@ def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value:
return self._with_callable(_fill_null, value=value)

def cast(self, dtype: IntoDType) -> Self:
def _func(expr: ir.Column) -> ir.Value:
def func(df: IbisLazyFrame) -> list[ir.Value]:
if dtype.is_numeric():
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

native_dtype = narwhals_to_native_dtype(dtype, self._version)
# ibis `cast` overloads do not include DataType, only literals
return expr.cast(native_dtype) # type: ignore[unused-ignore]
return [expr.cast(native_dtype) for expr in self(df)] # pyright: ignore[reportArgumentType, reportCallIssue]

return self._with_callable(_func)
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

def is_unique(self) -> Self:
return self._with_callable(
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
set_index,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import Implementation, is_list_of, no_default
from narwhals._utils import Implementation, Version, is_list_of, no_default
from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -43,7 +44,7 @@
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.typing import NativeSeriesT
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -311,6 +312,7 @@ def scatter(
return None if in_place else self._with_native(series)

def cast(self, dtype: IntoDType) -> Self:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
if self.dtype == dtype and self.native.dtype != "object":
# Avoid dealing with pandas' type-system if we can. Note that it's only
# safe to do this if we're not starting with object dtype, see tests/expr_and_series/cast_test.py::test_cast_object_pandas
Expand Down
8 changes: 5 additions & 3 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
narwhals_to_native_dtype,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, no_default, requires
from narwhals._utils import Implementation, Version, no_default, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
Expand All @@ -35,7 +36,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Expand Down Expand Up @@ -288,7 +289,8 @@ def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
return self._from_native_object(self.native.__getitem__(item))

def cast(self, dtype: IntoDType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
dtype_pl = narwhals_to_native_dtype(dtype, version=self._version)
return self._with_native(self.native.cast(dtype_pl))

def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self:
Expand Down
29 changes: 22 additions & 7 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import operator
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast

from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
Expand All @@ -23,6 +24,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import DType, _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
Expand All @@ -40,6 +42,7 @@
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._spark_like.utils import _NativeDType
from narwhals._typing import NoDefault
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, RankMethod
Expand Down Expand Up @@ -246,19 +249,31 @@ def __invert__(self) -> Self:
return self._with_elementwise(invert)

def cast(self, dtype: IntoDType) -> Self:
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
def _validated_dtype(dtype: IntoDType, df: SparkLikeLazyFrame) -> _NativeDType:
if dtype.is_numeric():
schema: dict[str, DType] = {}
with suppress(Exception):
schema = df.collect_schema()

if schema:
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(
source=schema[name], target=dtype
)
Comment on lines +254 to +262
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not particularly proud of this piece of code 🤔


return narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self(df)]

def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
func,
Expand Down
23 changes: 23 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ def _validate_into_dtype(dtype: Any) -> None:
raise TypeError(msg)


def _validate_cast_temporal_to_numeric(
source: DType | type[DType], target: IntoDType
) -> None:
"""Validate that we're not casting from temporal to numeric types.

Arguments:
source: The source data type.
target: The target data type to cast to.

Raises:
InvalidOperationError: If attempting to cast from temporal to integer.
"""
if source.is_temporal() and target.is_numeric():
msg = (
"Casting from temporal type to numeric is not supported.\n\n"
"Hint: Use `.dt` accessor methods instead, such as:\n"
" - `.dt.timestamp()` for Unix timestamp.\n"
" - `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.\n"
" - `.dt.total_seconds()`, `.dt.total_milliseconds(), ..., for duration total time."
)
raise InvalidOperationError(msg)


class DTypeClass(type):
"""Metaclass for DType classes.

Expand Down
9 changes: 9 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, we don't allow to cast from a temporal to a numeric data type.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL: polars allows also casting to Float, not only to Integer


Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds(), ..., for duration total time.

Examples:
>>> import pandas as pd
>>> import narwhals as nw
Expand Down
9 changes: 9 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, we don't allow to cast from a temporal to a numeric data type.

Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds(), ..., for duration total time.

Examples:
>>> import pyarrow as pa
>>> import narwhals as nw
Expand Down
Loading
Loading