diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index d88d80fd4e..b4d65c9e76 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -13,14 +13,15 @@ import pyarrow as pa import pyarrow.compute as pc +from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import align_series_full_broadcast from narwhals._arrow.utils import convert_str_slice_to_int_slice -from narwhals._arrow.utils import extract_dataframe_comparand from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._arrow.utils import select_rows from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind from narwhals.dependencies import is_numpy_array_1d +from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import check_column_exists @@ -369,21 +370,31 @@ def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: df = pa.Table.from_arrays([s._native_series for s in reshaped], names=names) return self._from_native_frame(df, validate_column_names=True) + def _extract_comparand(self, other: ArrowSeries) -> ArrowChunkedArray: + length = len(self) + if not other._broadcast: + if (len_other := len(other)) != length: + msg = f"Expected object of length {length}, got: {len_other}." + raise ShapeError(msg) + return other.native + + import numpy as np # ignore-banned-import + + value = other.native[0] + if self._backend_version < (13,) and hasattr(value, "as_py"): + value = value.as_py() + return pa.chunked_array([np.full(shape=length, fill_value=value)]) + def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame) # All `pyarrow` data is immutable, so this is fine native_frame = self.native new_columns = self._evaluate_into_exprs(*exprs) - - length = len(self) columns = self.columns for col_value in new_columns: col_name = col_value.name - - column = extract_dataframe_comparand( - length=length, other=col_value, backend_version=self._backend_version - ) + column = self._extract_comparand(col_value) native_frame = ( native_frame.set_column( columns.index(col_name), diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index a44a045d9d..543a9d167b 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -5,11 +5,10 @@ from itertools import chain from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterable from typing import Literal -from typing import Sequence +import pyarrow as pa import pyarrow.compute as pc from narwhals._arrow.dataframe import ArrowDataFrame @@ -19,28 +18,24 @@ from narwhals._arrow.utils import align_series_full_broadcast from narwhals._arrow.utils import cast_to_comparable_string_types from narwhals._arrow.utils import diagonal_concat -from narwhals._arrow.utils import extract_dataframe_comparand from narwhals._arrow.utils import horizontal_concat -from narwhals._arrow.utils import nulls_like from narwhals._arrow.utils import vertical_concat +from narwhals._compliant import CompliantThen from narwhals._compliant import EagerNamespace +from narwhals._compliant import EagerWhen from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module if TYPE_CHECKING: - from typing import Callable - from typing_extensions import Self - from typing_extensions import TypeAlias + from narwhals._arrow.typing import ArrowChunkedArray from narwhals._arrow.typing import Incomplete from narwhals.dtypes import DType from narwhals.utils import Version - _Scalar: TypeAlias = Any - class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]): @property @@ -253,7 +248,7 @@ def selectors(self: Self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace(self) def when(self: Self, predicate: ArrowExpr) -> ArrowWhen: - return ArrowWhen(predicate, self._backend_version, version=self._version) + return ArrowWhen.from_expr(predicate, context=self) def concat_str( self: Self, @@ -293,99 +288,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) -class ArrowWhen: - def __init__( - self: Self, - condition: ArrowExpr, - backend_version: tuple[int, ...], - then_value: ArrowExpr | _Scalar = None, - otherwise_value: ArrowExpr | _Scalar = None, - *, - version: Version, - ) -> None: - self._backend_version = backend_version - self._condition: ArrowExpr = condition - self._then_value: ArrowExpr | _Scalar = then_value - self._otherwise_value: ArrowExpr | _Scalar = otherwise_value - self._version = version - - def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: - condition = self._condition(df)[0] - condition_native = condition._native_series - - if isinstance(self._then_value, ArrowExpr): - value_series = self._then_value(df)[0] - else: - value_series = condition.alias("literal")._from_scalar(self._then_value) - value_series._broadcast = True - value_series_native = extract_dataframe_comparand( - len(df), value_series, self._backend_version - ) - - if self._otherwise_value is None: - otherwise_null = nulls_like(len(condition_native), value_series) - return [ - value_series._from_native_series( - pc.if_else(condition_native, value_series_native, otherwise_null) - ) - ] - if isinstance(self._otherwise_value, ArrowExpr): - otherwise_series = self._otherwise_value(df)[0] - else: - native_result = pc.if_else( - condition_native, value_series_native, self._otherwise_value - ) - return [value_series._from_native_series(native_result)] - - otherwise_series_native = extract_dataframe_comparand( - len(df), otherwise_series, self._backend_version - ) - return [ - value_series._from_native_series( - pc.if_else(condition_native, value_series_native, otherwise_series_native) - ) - ] - - def then(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowThen: - self._then_value = value +class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ArrowChunkedArray"]): + @property + def _then(self) -> type[ArrowThen]: + return ArrowThen - return ArrowThen( - self, - depth=0, - function_name="whenthen", - evaluate_output_names=getattr( - value, "_evaluate_output_names", lambda _df: ["literal"] - ), - alias_output_names=getattr(value, "_alias_output_names", None), - backend_version=self._backend_version, - version=self._version, - ) + def _if_then_else( + self, when: ArrowChunkedArray, then: ArrowChunkedArray, otherwise: Any, / + ) -> ArrowChunkedArray: + otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise + return pc.if_else(when, then, otherwise) -class ArrowThen(ArrowExpr): - def __init__( - self: Self, - call: ArrowWhen, - *, - depth: int, - function_name: str, - evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], - alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - backend_version: tuple[int, ...], - version: Version, - call_kwargs: dict[str, Any] | None = None, - implementation: Implementation | None = None, - ) -> None: - self._backend_version = backend_version - self._version = version - self._call: ArrowWhen = call - self._depth = depth - self._function_name = function_name - self._evaluate_output_names = evaluate_output_names - self._alias_output_names = alias_output_names - self._call_kwargs = call_kwargs or {} - - def otherwise(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowExpr: - self._call._otherwise_value = value - self._function_name = "whenotherwise" - return self +class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ... diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index abfb11cb6a..7964492041 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -13,7 +13,6 @@ import pyarrow as pa import pyarrow.compute as pc -from narwhals.exceptions import ShapeError from narwhals.utils import _SeriesNamespace from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass @@ -280,26 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: return reshaped -def extract_dataframe_comparand( - length: int, - other: ArrowSeries, - backend_version: tuple[int, ...], -) -> ArrowChunkedArray: - """Extract native Series, broadcasting to `length` if necessary.""" - if not other._broadcast: - if (len_other := len(other)) != length: - msg = f"Expected object of length {length}, got: {len_other}." - raise ShapeError(msg) - return other.native - - import numpy as np # ignore-banned-import - - value = other.native[0] - if backend_version < (13,) and hasattr(value, "as_py"): - value = value.as_py() - return pa.chunked_array([np.full(shape=length, fill_value=value)]) - - def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: """Concatenate (native) DataFrames horizontally. diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index f223aa6a3f..570aaad17b 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -29,6 +29,10 @@ from narwhals._compliant.typing import IntoCompliantExpr from narwhals._compliant.typing import NativeFrameT_co from narwhals._compliant.typing import NativeSeriesT_co +from narwhals._compliant.when_then import CompliantThen +from narwhals._compliant.when_then import CompliantWhen +from narwhals._compliant.when_then import EagerWhen +from narwhals._compliant.when_then import LazyWhen __all__ = [ "CompliantDataFrame", @@ -43,6 +47,8 @@ "CompliantSeries", "CompliantSeriesOrNativeExprT_co", "CompliantSeriesT", + "CompliantThen", + "CompliantWhen", "DepthTrackingGroupBy", "EagerDataFrame", "EagerDataFrameT", @@ -52,12 +58,14 @@ "EagerSelectorNamespace", "EagerSeries", "EagerSeriesT", + "EagerWhen", "EvalNames", "EvalSeries", "IntoCompliantExpr", "LazyExpr", "LazyGroupBy", "LazySelectorNamespace", + "LazyWhen", "NativeFrameT_co", "NativeSeriesT_co", ] diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 3c42403bba..4abc90218c 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -268,6 +268,10 @@ def unpivot( ) -> Self: ... def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ... def with_row_index(self, name: str) -> Self: ... + def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: + result = expr(self) + assert len(result) == 1 # debug assertion # noqa: S101 + return result[0] class EagerDataFrame( @@ -300,3 +304,7 @@ def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSerie msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}" raise AssertionError(msg) return result + + def _extract_comparand(self, other: EagerSeriesT, /) -> Any: + """Extract native Series, broadcasting to `len(self)` if necessary.""" + ... diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 7fe41baf6f..8d4187cb50 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -871,6 +871,10 @@ class LazyExpr( replace_strict: not_implemented = not_implemented() cat: not_implemented = not_implemented() # pyright: ignore[reportAssignmentType] + @classmethod + def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]: + return hasattr(obj, "__narwhals_expr__") + class EagerExprNamespace(_ExprNamespace[EagerExprT], Generic[EagerExprT]): def __init__(self, expr: EagerExprT, /) -> None: diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index f5449ec404..2521f4e889 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -12,17 +12,23 @@ from narwhals._compliant.typing import CompliantFrameT from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerExprT -from narwhals._compliant.typing import EagerSeriesT_co +from narwhals._compliant.typing import EagerSeriesT from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names from narwhals.utils import passthrough_column_names if TYPE_CHECKING: + from typing_extensions import TypeAlias + from narwhals._compliant.selectors import CompliantSelectorNamespace + from narwhals._compliant.when_then import CompliantWhen + from narwhals._compliant.when_then import EagerWhen from narwhals.dtypes import DType from narwhals.utils import Implementation from narwhals.utils import Version + Incomplete: TypeAlias = Any + __all__ = ["CompliantNamespace", "EagerNamespace"] @@ -65,7 +71,9 @@ def concat( *, how: Literal["horizontal", "vertical", "diagonal"], ) -> CompliantFrameT: ... - def when(self, predicate: CompliantExprT) -> Any: ... + def when( + self, predicate: CompliantExprT + ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ... def concat_str( self, *exprs: CompliantExprT, @@ -80,7 +88,10 @@ def _expr(self) -> type[CompliantExprT]: ... class EagerNamespace( CompliantNamespace[EagerDataFrameT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], + Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT], ): @property - def _series(self) -> type[EagerSeriesT_co]: ... + def _series(self) -> type[EagerSeriesT]: ... + def when( + self, predicate: EagerExprT + ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, Incomplete]: ... diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index c1771bde16..4d7abaa0e5 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -33,54 +33,64 @@ "NativeFrameT_co", "NativeSeriesT_co", ] +CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]" +CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]" +CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr" +CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any]" +CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any]" +CompliantFrameAny: TypeAlias = "CompliantDataFrameAny | CompliantLazyFrameAny" + +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" +EagerSeriesAny: TypeAlias = "EagerSeries[Any]" +EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" +EagerNamespaceAny: TypeAlias = ( + "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny]" +) + +LazyExprAny: TypeAlias = "LazyExpr[Any, Any]" + +NativeExprT = TypeVar("NativeExprT", bound="NativeExpr") NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True) +NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries") NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True) -CompliantSeriesT = TypeVar("CompliantSeriesT", bound="CompliantSeries[Any]") +NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True) + +CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny) +CompliantExprT_contra = TypeVar( + "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True +) +CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny) +CompliantSeriesOrNativeExprT = TypeVar( + "CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny +) CompliantSeriesOrNativeExprT_co = TypeVar( "CompliantSeriesOrNativeExprT_co", - bound="CompliantSeries[Any] | NativeExpr", + bound=CompliantSeriesOrNativeExprAny, covariant=True, ) -NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True) -CompliantFrameT = TypeVar( - "CompliantFrameT", - bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", -) +CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny) CompliantFrameT_co = TypeVar( - "CompliantFrameT_co", - bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", - covariant=True, -) -CompliantDataFrameT = TypeVar( - "CompliantDataFrameT", bound="CompliantDataFrame[Any, Any, Any]" + "CompliantFrameT_co", bound=CompliantFrameAny, covariant=True ) +CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny) CompliantDataFrameT_co = TypeVar( - "CompliantDataFrameT_co", bound="CompliantDataFrame[Any, Any, Any]", covariant=True + "CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True ) -CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame[Any, Any]") +CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny) CompliantLazyFrameT_co = TypeVar( - "CompliantLazyFrameT_co", bound="CompliantLazyFrame[Any, Any]", covariant=True + "CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True ) IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" -CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]" -CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny) -CompliantExprT_contra = TypeVar( - "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True -) +EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny) +EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True) +EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny) +EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True) + +# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`? EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") -EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]") -EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound="EagerSeries[Any]", covariant=True) -EagerExprT = TypeVar("EagerExprT", bound="EagerExpr[Any, Any]") -EagerExprT_contra = TypeVar( - "EagerExprT_contra", bound="EagerExpr[Any, Any]", contravariant=True -) -EagerNamespaceAny: TypeAlias = ( - "EagerNamespace[EagerDataFrame[Any, Any, Any], EagerSeries[Any], EagerExpr[Any, Any]]" -) -LazyExprT_contra = TypeVar( - "LazyExprT_contra", bound="LazyExpr[Any, Any]", contravariant=True -) + +LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True) AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]] AliasName: TypeAlias = Callable[[str], str] diff --git a/narwhals/_compliant/when_then.py b/narwhals/_compliant/when_then.py new file mode 100644 index 0000000000..f611062d30 --- /dev/null +++ b/narwhals/_compliant/when_then.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Sequence +from typing import TypeVar +from typing import cast + +from narwhals._compliant.expr import CompliantExpr +from narwhals._compliant.typing import CompliantExprAny +from narwhals._compliant.typing import CompliantFrameAny +from narwhals._compliant.typing import CompliantLazyFrameT +from narwhals._compliant.typing import CompliantSeriesOrNativeExprAny +from narwhals._compliant.typing import EagerDataFrameT +from narwhals._compliant.typing import EagerExprT +from narwhals._compliant.typing import EagerSeriesT +from narwhals._compliant.typing import LazyExprAny +from narwhals._compliant.typing import NativeExprT +from narwhals._compliant.typing import NativeSeriesT + +if TYPE_CHECKING: + from typing_extensions import Self + from typing_extensions import TypeAlias + + from narwhals.utils import Implementation + from narwhals.utils import Version + from narwhals.utils import _FullContext + +if not TYPE_CHECKING: # pragma: no cover + if sys.version_info >= (3, 9): + from typing import Protocol as Protocol38 + else: + from typing import Generic as Protocol38 +else: # pragma: no cover + # TODO @dangotbanned: Remove after dropping `3.8` (#2084) + # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 + from typing import Protocol as Protocol38 + +__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen", "LazyWhen"] + +ExprT = TypeVar("ExprT", bound=CompliantExprAny) +LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) +SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny) +FrameT = TypeVar("FrameT", bound=CompliantFrameAny) + +Scalar: TypeAlias = Any +"""A native or python literal value.""" + +IntoExpr: TypeAlias = "SeriesT | ExprT | Scalar" +"""Anything that is convertible into a `CompliantExpr`.""" + + +class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]): + _condition: ExprT + _then_value: IntoExpr[SeriesT, ExprT] + _otherwise_value: IntoExpr[SeriesT, ExprT] + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + @property + def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT]]: ... + def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ... + + def then( + self, value: IntoExpr[SeriesT, ExprT], / + ) -> CompliantThen[FrameT, SeriesT, ExprT]: + return self._then.from_when(self, value) + + @classmethod + def from_expr(cls, condition: ExprT, /, *, context: _FullContext) -> Self: + obj = cls.__new__(cls) + obj._condition = condition + obj._then_value = None + obj._otherwise_value = None + obj._implementation = context._implementation + obj._backend_version = context._backend_version + obj._version = context._version + return obj + + +class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT, ExprT]): + _call: Callable[[FrameT], Sequence[SeriesT]] + _when_value: CompliantWhen[FrameT, SeriesT, ExprT] + _function_name: str + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + _call_kwargs: dict[str, Any] + + @classmethod + def from_when( + cls, + when: CompliantWhen[FrameT, SeriesT, ExprT], + then: IntoExpr[SeriesT, ExprT], + /, + ) -> Self: + when._then_value = then + obj = cls.__new__(cls) + obj._call = when + obj._when_value = when + obj._depth = 0 + obj._function_name = "whenthen" + obj._evaluate_output_names = getattr( + then, "_evaluate_output_names", lambda _df: ["literal"] + ) + obj._alias_output_names = getattr(then, "_alias_output_names", None) + obj._implementation = when._implementation + obj._backend_version = when._backend_version + obj._version = when._version + obj._call_kwargs = {} + return obj + + def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT: + self._when_value._otherwise_value = otherwise + self._function_name = "whenotherwise" + return cast("ExprT", self) + + +class EagerWhen( + CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT], + Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT], +): + def _if_then_else( + self, + when: NativeSeriesT, + then: NativeSeriesT, + otherwise: NativeSeriesT | Scalar | None, + /, + ) -> NativeSeriesT: ... + + def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]: + is_expr = self._condition._is_expr + when: EagerSeriesT = self._condition(df)[0] + then: EagerSeriesT + if is_expr(self._then_value): + then = self._then_value(df)[0] + else: + then = when.alias("literal")._from_scalar(self._then_value) + then._broadcast = True + if is_expr(self._otherwise_value): + otherwise = df._extract_comparand(self._otherwise_value(df)[0]) + else: + otherwise = self._otherwise_value + result = self._if_then_else(when.native, df._extract_comparand(then), otherwise) + return [then._from_native_series(result)] + + +class LazyWhen( + CompliantWhen[CompliantLazyFrameT, NativeExprT, LazyExprT], + Protocol38[CompliantLazyFrameT, NativeExprT, LazyExprT], +): + when: Callable[..., NativeExprT] + lit: Callable[..., NativeExprT] + + def __call__(self: Self, df: CompliantLazyFrameT) -> Sequence[NativeExprT]: + is_expr = self._condition._is_expr + when = self.when + lit = self.lit + condition = df._evaluate_expr(self._condition) + then_ = self._then_value + then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_) + other_ = self._otherwise_value + if other_ is None: + result = when(condition, then) + else: + otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_) + result = when(condition, then).otherwise(otherwise) # type: ignore # noqa: PGH003 + return [result] diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 9ea92b75c9..e9b232aa6f 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -4,7 +4,6 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterable from typing import Literal from typing import Sequence @@ -13,6 +12,8 @@ import pandas as pd from narwhals._compliant import CompliantNamespace +from narwhals._compliant import CompliantThen +from narwhals._compliant import CompliantWhen from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace @@ -37,7 +38,7 @@ import dask_expr as dx -class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]): +class DaskNamespace(CompliantNamespace[DaskLazyFrame, DaskExpr]): _implementation: Implementation = Implementation.DASK @property @@ -257,7 +258,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) def when(self: Self, predicate: DaskExpr) -> DaskWhen: - return DaskWhen(predicate, self._backend_version, version=self._version) + return DaskWhen.from_expr(predicate, context=self) def concat_str( self: Self, @@ -307,21 +308,10 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) -class DaskWhen: - def __init__( - self: Self, - condition: DaskExpr, - backend_version: tuple[int, ...], - then_value: Any = None, - otherwise_value: Any = None, - *, - version: Version, - ) -> None: - self._backend_version = backend_version - self._condition: DaskExpr = condition - self._then_value: DaskExpr | Any = then_value - self._otherwise_value: DaskExpr | Any = otherwise_value - self._version = version +class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): + @property + def _then(self) -> type[DaskThen]: + return DaskThen def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]: condition = self._condition(df)[0] @@ -339,50 +329,10 @@ def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]: if isinstance(self._otherwise_value, DaskExpr): otherwise_value = self._otherwise_value(df)[0] else: - return [then_series.where(condition, self._otherwise_value)] + return [then_series.where(condition, self._otherwise_value)] # pyright: ignore[reportArgumentType] (otherwise_series,) = align_series_full_broadcast(df, otherwise_value) validate_comparand(condition, otherwise_series) return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType] - def then(self: Self, value: DaskExpr | Any) -> DaskThen: - self._then_value = value - return DaskThen( - self, - depth=0, - function_name="whenthen", - evaluate_output_names=getattr( - value, "_evaluate_output_names", lambda _df: ["literal"] - ), - alias_output_names=getattr(value, "_alias_output_names", None), - backend_version=self._backend_version, - version=self._version, - ) - - -class DaskThen(DaskExpr): - def __init__( - self: Self, - call: DaskWhen, - *, - depth: int, - function_name: str, - evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]], - alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - backend_version: tuple[int, ...], - version: Version, - call_kwargs: dict[str, Any] | None = None, - ) -> None: - self._backend_version = backend_version - self._version = version - self._call: DaskWhen = call - self._depth = depth - self._function_name = function_name - self._evaluate_output_names = evaluate_output_names - self._alias_output_names = alias_output_names - self._call_kwargs = call_kwargs or {} - - def otherwise(self: Self, value: DaskExpr | Any) -> DaskExpr: - self._call._otherwise_value = value - self._function_name = "whenotherwise" - return self +class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr): ... diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index b54dc83a7c..47241ddfd3 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -25,7 +25,6 @@ from narwhals._duckdb.utils import generate_order_by_sql from narwhals._duckdb.utils import generate_partition_by_sql from narwhals._duckdb.utils import lit -from narwhals._duckdb.utils import maybe_evaluate_expr from narwhals._duckdb.utils import narwhals_to_native_dtype from narwhals._expression_parsing import ExprKind from narwhals.utils import Implementation @@ -173,9 +172,9 @@ def _from_call( """ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - native_series_list = self._call(df) + native_series_list = self(df) other_native_series = { - key: maybe_evaluate_expr(df, value) + key: df._evaluate_expr(value) if self._is_expr(value) else lit(value) for key, value in expressifiable_args.items() } return [ diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 121b91916d..b0263522e6 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -4,11 +4,11 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterable from typing import Literal from typing import Sequence +import duckdb from duckdb import CaseExpression from duckdb import CoalesceOperator from duckdb import FunctionExpression @@ -16,17 +16,18 @@ from duckdb.typing import VARCHAR from narwhals._compliant import CompliantNamespace +from narwhals._compliant import CompliantThen +from narwhals._compliant import LazyWhen from narwhals._duckdb.expr import DuckDBExpr from narwhals._duckdb.selectors import DuckDBSelectorNamespace from narwhals._duckdb.utils import lit -from narwhals._duckdb.utils import maybe_evaluate_expr from narwhals._duckdb.utils import narwhals_to_native_dtype +from narwhals._duckdb.utils import when from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation if TYPE_CHECKING: - import duckdb from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame @@ -217,11 +218,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: ) def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen: - return DuckDBWhen( - predicate, - self._backend_version, - version=self._version, - ) + return DuckDBWhen.from_expr(predicate, context=self) def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr: def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: @@ -256,71 +253,17 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: ) -class DuckDBWhen: - def __init__( - self: Self, - condition: DuckDBExpr, - backend_version: tuple[int, ...], - then_value: Any = None, - otherwise_value: Any = None, - *, - version: Version, - ) -> None: - self._backend_version = backend_version - self._condition = condition - self._then_value = then_value - self._otherwise_value = otherwise_value - self._version = version +class DuckDBWhen(LazyWhen["DuckDBLazyFrame", duckdb.Expression, DuckDBExpr]): + @property + def _then(self) -> type[DuckDBThen]: + return DuckDBThen def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: - condition = maybe_evaluate_expr(df, self._condition) - then_value = maybe_evaluate_expr(df, self._then_value) - if self._otherwise_value is None: - return [CaseExpression(condition=condition, value=then_value)] - otherwise_value = maybe_evaluate_expr(df, self._otherwise_value) - return [ - CaseExpression(condition=condition, value=then_value).otherwise( - otherwise_value - ) - ] - - def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen: - self._then_value = value - - return DuckDBThen( - self, - function_name="whenthen", - evaluate_output_names=getattr( - value, "_evaluate_output_names", lambda _df: ["literal"] - ), - alias_output_names=getattr(value, "_alias_output_names", None), - backend_version=self._backend_version, - version=self._version, - ) + self.when = when + self.lit = lit + return super().__call__(df) -class DuckDBThen(DuckDBExpr): - def __init__( - self: Self, - call: DuckDBWhen, - *, - function_name: str, - evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], - alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - backend_version: tuple[int, ...], - version: Version, - ) -> None: - self._backend_version = backend_version - self._version = version - self._call = call - self._function_name = function_name - self._evaluate_output_names = evaluate_output_names - self._alias_output_names = alias_output_names - - def otherwise(self: Self, value: DuckDBExpr | Any) -> DuckDBExpr: - # type ignore because we are setting the `_call` attribute to a - # callable object of type `DuckDBWhen`, base class has the attribute as - # only a `Callable` - self._call._otherwise_value = value # type: ignore[attr-defined] - self._function_name = "whenotherwise" - return self +class DuckDBThen( + CompliantThen["DuckDBLazyFrame", duckdb.Expression, DuckDBExpr], DuckDBExpr +): ... diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index a9f0364f8d..c00d2e95cf 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -20,6 +20,9 @@ lit = duckdb.ConstantExpression """Alias for `duckdb.ConstantExpression`.""" +when = duckdb.CaseExpression +"""Alias for `duckdb.CaseExpression`.""" + class WindowInputs: __slots__ = ("expr", "order_by", "partition_by") @@ -35,18 +38,6 @@ def __init__( self.order_by = order_by -def maybe_evaluate_expr( - df: DuckDBLazyFrame, obj: DuckDBExpr | object -) -> duckdb.Expression: - from narwhals._duckdb.expr import DuckDBExpr - - if isinstance(obj, DuckDBExpr): - column_results = obj._call(df) - assert len(column_results) == 1 # debug assertion # noqa: S101 - return column_results[0] - return lit(obj) - - def evaluate_exprs( df: DuckDBLazyFrame, /, *exprs: DuckDBExpr ) -> list[tuple[str, duckdb.Expression]]: diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 8605eac3ee..52242c1bc4 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -17,15 +17,16 @@ from narwhals._pandas_like.utils import align_series_full_broadcast from narwhals._pandas_like.utils import check_column_names_are_unique from narwhals._pandas_like.utils import convert_str_slice_to_int_slice -from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import object_native_to_narwhals_dtype from narwhals._pandas_like.utils import pivot_table from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name +from narwhals._pandas_like.utils import set_index from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import check_column_exists from narwhals.utils import generate_temporary_column_name @@ -149,6 +150,23 @@ def _from_native_frame( validate_column_names=validate_column_names, ) + def _extract_comparand(self, other: PandasLikeSeries) -> pd.Series[Any]: + index = self.native.index + if other._broadcast: + s = other.native + return type(s)(s.iloc[0], index=index, dtype=s.dtype, name=s.name) + if (len_other := len(other)) != (len_idx := len(index)): + msg = f"Expected object of length {len_idx}, got: {len_other}." + raise ShapeError(msg) + if other.native.index is not index: + return set_index( + other.native, + index, + implementation=other._implementation, + backend_version=other._backend_version, + ) + return other.native + def get_column(self: Self, name: str) -> PandasLikeSeries: return PandasLikeSeries( self.native[name], @@ -450,8 +468,7 @@ def filter( else: # `[0]` is safe as the predicate's expression only returns a single column mask = self._evaluate_into_exprs(predicate)[0] - mask_native = extract_dataframe_comparand(self.native.index, mask) - + mask_native = self._extract_comparand(mask) return self._from_native_frame( self.native.loc[mask_native], validate_column_names=False ) @@ -459,7 +476,6 @@ def filter( def with_columns( self: PandasLikeDataFrame, *exprs: PandasLikeExpr ) -> PandasLikeDataFrame: - index = self.native.index new_columns = self._evaluate_into_exprs(*exprs) if not new_columns and len(self) == 0: return self @@ -470,14 +486,12 @@ def with_columns( for name in self.native.columns: if name in new_column_name_to_new_column_map: to_concat.append( - extract_dataframe_comparand( - index, new_column_name_to_new_column_map.pop(name) - ) + self._extract_comparand(new_column_name_to_new_column_map.pop(name)) ) else: to_concat.append(self.native[name]) to_concat.extend( - extract_dataframe_comparand(index, new_column_name_to_new_column_map[s]) + self._extract_comparand(new_column_name_to_new_column_map[s]) for s in new_column_name_to_new_column_map ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index dc525cb634..5352cfa8dc 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -4,12 +4,12 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterable from typing import Literal -from typing import Sequence +from narwhals._compliant import CompliantThen from narwhals._compliant import EagerNamespace +from narwhals._compliant import EagerWhen from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._pandas_like.dataframe import PandasLikeDataFrame @@ -18,21 +18,18 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.utils import align_series_full_broadcast from narwhals._pandas_like.utils import diagonal_concat -from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat from narwhals.utils import import_dtypes_module if TYPE_CHECKING: + import pandas as pd from typing_extensions import Self - from typing_extensions import TypeAlias from narwhals.dtypes import DType from narwhals.utils import Implementation from narwhals.utils import Version - _Scalar: TypeAlias = Any - class PandasLikeNamespace( EagerNamespace[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr] @@ -264,9 +261,7 @@ def concat( raise NotImplementedError def when(self: Self, predicate: PandasLikeExpr) -> PandasWhen: - return PandasWhen( - predicate, self._implementation, self._backend_version, version=self._version - ) + return PandasWhen.from_expr(predicate, context=self) def concat_str( self: Self, @@ -318,106 +313,19 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) -class PandasWhen: - def __init__( - self: Self, - condition: PandasLikeExpr, - implementation: Implementation, - backend_version: tuple[int, ...], - then_value: PandasLikeExpr | _Scalar = None, - otherwise_value: PandasLikeExpr | _Scalar = None, - *, - version: Version, - ) -> None: - self._implementation = implementation - self._backend_version = backend_version - self._condition: PandasLikeExpr = condition - self._then_value: PandasLikeExpr | _Scalar = then_value - self._otherwise_value: PandasLikeExpr | _Scalar = otherwise_value - self._version = version - - def __call__(self: Self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: - condition = self._condition(df)[0] - condition_native = condition._native_series - - if isinstance(self._then_value, PandasLikeExpr): - value_series = self._then_value(df)[0] - else: - value_series = condition.alias("literal")._from_scalar(self._then_value) - value_series._broadcast = True - value_series_native = extract_dataframe_comparand( - df._native_frame.index, value_series - ) - - if self._otherwise_value is None: - return [ - value_series._from_native_series( - value_series_native.where(condition_native) - ) - ] - - if isinstance(self._otherwise_value, PandasLikeExpr): - otherwise_series = self._otherwise_value(df)[0] - else: - native_result = value_series_native.where( - condition_native, self._otherwise_value - ) - return [value_series._from_native_series(native_result)] - otherwise_series_native = extract_dataframe_comparand( - df._native_frame.index, otherwise_series - ) - return [ - value_series._from_native_series( - value_series_native.where(condition_native, otherwise_series_native) - ) - ] - - def then( - self: Self, value: PandasLikeExpr | PandasLikeSeries | _Scalar - ) -> PandasThen: - self._then_value = value +class PandasWhen( + EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, "pd.Series[Any]"] +): + @property + def _then(self) -> type[PandasThen]: + return PandasThen - return PandasThen( - self, - depth=0, - function_name="whenthen", - evaluate_output_names=getattr( - value, "_evaluate_output_names", lambda _df: ["literal"] - ), - alias_output_names=getattr(value, "_alias_output_names", None), - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) + def _if_then_else( + self, when: pd.Series[Any], then: pd.Series[Any], otherwise: Any, / + ) -> pd.Series[Any]: + return then.where(when) if otherwise is None else then.where(when, otherwise) -class PandasThen(PandasLikeExpr): - def __init__( - self: Self, - call: PandasWhen, - *, - depth: int, - function_name: str, - evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]], - alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - implementation: Implementation, - backend_version: tuple[int, ...], - version: Version, - call_kwargs: dict[str, Any] | None = None, - ) -> None: - self._implementation = implementation - self._backend_version = backend_version - self._version = version - self._call: PandasWhen = call - self._depth = depth - self._function_name = function_name - self._evaluate_output_names = evaluate_output_names - self._alias_output_names = alias_output_names - self._call_kwargs = call_kwargs or {} - - def otherwise( - self: Self, value: PandasLikeExpr | PandasLikeSeries | _Scalar - ) -> PandasLikeExpr: - self._call._otherwise_value = value - self._function_name = "whenotherwise" - return self +class PandasThen( + CompliantThen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr], PandasLikeExpr +): ... diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 8fbee07f82..b15544a453 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -15,7 +15,6 @@ from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import DuplicateError -from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import import_dtypes_module @@ -132,26 +131,6 @@ def align_and_extract_native( return lhs_native, rhs -def extract_dataframe_comparand( - index: pd.Index[Any], other: PandasLikeSeries -) -> pd.Series[Any]: - """Extract native Series, broadcasting to `length` if necessary.""" - if other._broadcast: - s = other._native_series - return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name) - if (len_other := len(other)) != (len_idx := len(index)): - msg = f"Expected object of length {len_idx}, got: {len_other}." - raise ShapeError(msg) - if other._native_series.index is not index: - return set_index( - other._native_series, - index, - implementation=other._implementation, - backend_version=other._backend_version, - ) - return other._native_series - - def horizontal_concat( dfs: list[Any], *, implementation: Implementation, backend_version: tuple[int, ...] ) -> Any: diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 09c24042c9..c0f702c86d 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -432,6 +432,8 @@ class PolarsLazyFrame: tail: Method[Self] unique: Method[Self] with_columns: Method[Self] + # NOTE: Temporary, just trying to factor out utils + _evaluate_expr: Any def __init__( self: Self, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index b437da098e..e6c66ffa7c 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -19,7 +19,6 @@ from narwhals._spark_like.utils import import_functions from narwhals._spark_like.utils import import_native_dtypes from narwhals._spark_like.utils import import_window -from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.dependencies import get_pyspark from narwhals.utils import Implementation @@ -225,9 +224,10 @@ def _from_call( **expressifiable_args: Self | Any, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: - native_series_list = self._call(df) + native_series_list = self(df) + lit = df._F.lit other_native_series = { - key: maybe_evaluate_expr(df, value) + key: df._evaluate_expr(value) if self._is_expr(value) else lit(value) for key, value in expressifiable_args.items() } return [ diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 6988ba9cc3..13d8bb0907 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -3,19 +3,18 @@ import operator from functools import reduce from typing import TYPE_CHECKING -from typing import Any -from typing import Callable from typing import Iterable from typing import Literal from typing import Sequence from narwhals._compliant import CompliantNamespace +from narwhals._compliant import CompliantThen +from narwhals._compliant import LazyWhen from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.selectors import SparkLikeSelectorNamespace -from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype if TYPE_CHECKING: @@ -284,82 +283,20 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) def when(self: Self, predicate: SparkLikeExpr) -> SparkLikeWhen: - return SparkLikeWhen( - predicate, - self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return SparkLikeWhen.from_expr(predicate, context=self) -class SparkLikeWhen: - def __init__( - self: Self, - condition: SparkLikeExpr, - backend_version: tuple[int, ...], - then_value: Any | None = None, - otherwise_value: Any | None = None, - *, - version: Version, - implementation: Implementation, - ) -> None: - self._backend_version = backend_version - self._condition = condition - self._then_value = then_value - self._otherwise_value = otherwise_value - self._version = version - self._implementation = implementation - - def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: - condition = maybe_evaluate_expr(df, self._condition) - then_value = maybe_evaluate_expr(df, self._then_value) - if self._otherwise_value is None: - return [df._F.when(condition=condition, value=then_value)] - otherwise_value = maybe_evaluate_expr(df, self._otherwise_value) - return [ - df._F.when(condition=condition, value=then_value).otherwise(otherwise_value) - ] - - def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: - self._then_value = value - - return SparkLikeThen( - self, - function_name="whenthen", - evaluate_output_names=getattr( - value, "_evaluate_output_names", lambda _df: ["literal"] - ), - alias_output_names=getattr(value, "_alias_output_names", None), - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) +class SparkLikeWhen(LazyWhen[SparkLikeLazyFrame, "Column", SparkLikeExpr]): + @property + def _then(self) -> type[SparkLikeThen]: + return SparkLikeThen + def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]: + self.when = df._F.when + self.lit = df._F.lit + return super().__call__(df) -class SparkLikeThen(SparkLikeExpr): - def __init__( - self: Self, - call: SparkLikeWhen, - *, - function_name: str, - evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], - alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - backend_version: tuple[int, ...], - version: Version, - implementation: Implementation, - ) -> None: - self._backend_version = backend_version - self._version = version - self._call = call - self._function_name = function_name - self._evaluate_output_names = evaluate_output_names - self._alias_output_names = alias_output_names - self._implementation = implementation - def otherwise(self: Self, value: SparkLikeExpr | Any) -> SparkLikeExpr: - # type ignore because we are setting the `_call` attribute to a - # callable object of type `SparkLikeWhen`, base class has the attribute as - # only a `Callable` - self._call._otherwise_value = value # type: ignore[attr-defined] - self._function_name = "whenotherwise" - return self +class SparkLikeThen( + CompliantThen[SparkLikeLazyFrame, "Column", SparkLikeExpr], SparkLikeExpr +): ... diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index c52952955a..6c7b5d2139 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -192,16 +192,6 @@ def evaluate_exprs( return native_results -def maybe_evaluate_expr(df: SparkLikeLazyFrame, obj: SparkLikeExpr | object) -> Column: - from narwhals._spark_like.expr import SparkLikeExpr - - if isinstance(obj, SparkLikeExpr): - column_results = obj._call(df) - assert len(column_results) == 1 # debug assertion # noqa: S101 - return column_results[0] - return df._F.lit(obj) - - def _std( column: Column, ddof: int,