Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f6deab5
feat(DRAFT): Adds `Compliant(When|Then)`
dangotbanned Mar 21, 2025
db5c25f
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 21, 2025
3486f55
chore(DRAFT): Try out more complete `CompliantThen`
dangotbanned Mar 21, 2025
f1689e7
fix(typing): make `SeriesT` invariant
dangotbanned Mar 21, 2025
de3ccad
fix(typing): Sneaky `_when` to avoid `_call` overloading
dangotbanned Mar 21, 2025
8d671b2
feat: Implement `Arrow(Then|When)`
dangotbanned Mar 21, 2025
1ac7daa
fix: Older `Protocol` compat?
dangotbanned Mar 21, 2025
9e7d848
chore: imports/exports
dangotbanned Mar 21, 2025
ebbefa8
feat: Implement `Dask(Then|When)`
dangotbanned Mar 21, 2025
6767c38
feat: Implement `Pandas(When|Then)`
dangotbanned Mar 21, 2025
396605e
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
5a22284
refactor: `CompliantWhen.__init__` -> `.from_expr`
dangotbanned Mar 22, 2025
c002fcf
refactor: Naming conventions
dangotbanned Mar 22, 2025
3b81fbe
refactor(typing): Clean up, expose bound aliases
dangotbanned Mar 22, 2025
f2149e6
Make annotations readable again
dangotbanned Mar 22, 2025
32d2741
typo
dangotbanned Mar 22, 2025
541f532
feat: Implement `DuckDB(When|Then)`
dangotbanned Mar 22, 2025
6183be7
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
724ff96
feat: Implement `SparkLike(When|Then)`
dangotbanned Mar 22, 2025
a9a6ccd
refactor: Add `EagerDataFrame._extract_comparand`
dangotbanned Mar 22, 2025
5adda67
feat: Implement `EagerWhen` :sunglasses:
dangotbanned Mar 22, 2025
9c9ea48
chore(typing): Partially annotate `.when`
dangotbanned Mar 22, 2025
cc5cce1
refactor: Initial `LazyWhen` idea
dangotbanned Mar 22, 2025
b0b1a05
refactor(DRAFT): `LazyWhen`
dangotbanned Mar 22, 2025
9c680dc
remove dead code
dangotbanned Mar 22, 2025
ededf10
refactor: remove from `spark_like` as well
dangotbanned Mar 22, 2025
2e38646
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
9c85a3f
Merge branch 'main' into compliant-when
dangotbanned Mar 22, 2025
ed2e7e8
Merge branch 'main' into compliant-when
MarcoGorelli Mar 23, 2025
07f51da
merge fixup
MarcoGorelli Mar 23, 2025
a339a5e
revert: redo merge conflict from (https://github.com/narwhals-dev/nar…
dangotbanned Mar 23, 2025
9632f3a
Merge branch 'compliant-when' of https://github.com/narwhals-dev/narw…
dangotbanned Mar 23, 2025
328bb43
Update imports/exports
dangotbanned Mar 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import extract_dataframe_comparand
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import select_rows
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import check_column_exists
Expand Down Expand Up @@ -369,21 +370,31 @@ def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
df = pa.Table.from_arrays([s._native_series for s in reshaped], names=names)
return self._from_native_frame(df, validate_column_names=True)

def _extract_comparand(self, other: ArrowSeries) -> ArrowChunkedArray:
    """Return the native data of `other`, sized to match this frame.

    When `other._broadcast` is falsy, `other` must already have `len(self)`
    elements and its `native` attribute is returned unchanged. Otherwise the
    first element of `other.native` is repeated `len(self)` times with
    `numpy.full` and wrapped in a single-chunk `pa.chunked_array`.

    Raises:
        ShapeError: If `other` is not flagged for broadcasting and its length
            differs from `len(self)`.
    """
    expected = len(self)

    if other._broadcast:
        import numpy as np  # ignore-banned-import

        scalar = other.native[0]
        # For backend versions below 13 the element is converted with
        # `as_py()` (when it has that method) before filling the array.
        unwrap = self._backend_version < (13,) and hasattr(scalar, "as_py")
        fill = scalar.as_py() if unwrap else scalar
        return pa.chunked_array([np.full(shape=expected, fill_value=fill)])

    actual = len(other)
    if actual != expected:
        msg = f"Expected object of length {expected}, got: {actual}."
        raise ShapeError(msg)
    return other.native

def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
# NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
# All `pyarrow` data is immutable, so this is fine
native_frame = self.native
new_columns = self._evaluate_into_exprs(*exprs)

length = len(self)
columns = self.columns

for col_value in new_columns:
col_name = col_value.name

column = extract_dataframe_comparand(
length=length, other=col_value, backend_version=self._backend_version
)
column = self._extract_comparand(col_value)
native_frame = (
native_frame.set_column(
columns.index(col_name),
Expand Down
118 changes: 15 additions & 103 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import Sequence

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
Expand All @@ -19,28 +18,24 @@
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._arrow.utils import diagonal_concat
from narwhals._arrow.utils import extract_dataframe_comparand
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import nulls_like
from narwhals._arrow.utils import vertical_concat
from narwhals._compliant import CompliantThen
from narwhals._compliant import EagerNamespace
from narwhals._compliant import EagerWhen
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from typing import Callable

from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._arrow.typing import Incomplete
from narwhals.dtypes import DType
from narwhals.utils import Version

_Scalar: TypeAlias = Any


class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
@property
Expand Down Expand Up @@ -253,7 +248,7 @@ def selectors(self: Self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace(self)

def when(self: Self, predicate: ArrowExpr) -> ArrowWhen:
return ArrowWhen(predicate, self._backend_version, version=self._version)
return ArrowWhen.from_expr(predicate, context=self)

def concat_str(
self: Self,
Expand Down Expand Up @@ -293,99 +288,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)


class ArrowWhen:
def __init__(
self: Self,
condition: ArrowExpr,
backend_version: tuple[int, ...],
then_value: ArrowExpr | _Scalar = None,
otherwise_value: ArrowExpr | _Scalar = None,
*,
version: Version,
) -> None:
self._backend_version = backend_version
self._condition: ArrowExpr = condition
self._then_value: ArrowExpr | _Scalar = then_value
self._otherwise_value: ArrowExpr | _Scalar = otherwise_value
self._version = version

def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
condition = self._condition(df)[0]
condition_native = condition._native_series

if isinstance(self._then_value, ArrowExpr):
value_series = self._then_value(df)[0]
else:
value_series = condition.alias("literal")._from_scalar(self._then_value)
value_series._broadcast = True
value_series_native = extract_dataframe_comparand(
len(df), value_series, self._backend_version
)

if self._otherwise_value is None:
otherwise_null = nulls_like(len(condition_native), value_series)
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_null)
)
]
if isinstance(self._otherwise_value, ArrowExpr):
otherwise_series = self._otherwise_value(df)[0]
else:
native_result = pc.if_else(
condition_native, value_series_native, self._otherwise_value
)
return [value_series._from_native_series(native_result)]

otherwise_series_native = extract_dataframe_comparand(
len(df), otherwise_series, self._backend_version
)
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_series_native)
)
]

def then(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowThen:
self._then_value = value
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ArrowChunkedArray"]):
@property
def _then(self) -> type[ArrowThen]:
return ArrowThen

return ArrowThen(
self,
depth=0,
function_name="whenthen",
evaluate_output_names=getattr(
value, "_evaluate_output_names", lambda _df: ["literal"]
),
alias_output_names=getattr(value, "_alias_output_names", None),
backend_version=self._backend_version,
version=self._version,
)
def _if_then_else(
self, when: ArrowChunkedArray, then: ArrowChunkedArray, otherwise: Any, /
) -> ArrowChunkedArray:
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
return pc.if_else(when, then, otherwise)


class ArrowThen(ArrowExpr):
def __init__(
self: Self,
call: ArrowWhen,
*,
depth: int,
function_name: str,
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
call_kwargs: dict[str, Any] | None = None,
implementation: Implementation | None = None,
) -> None:
self._backend_version = backend_version
self._version = version
self._call: ArrowWhen = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._call_kwargs = call_kwargs or {}

def otherwise(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowExpr:
self._call._otherwise_value = value
self._function_name = "whenotherwise"
return self
class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ...
Copy link
Member Author

@dangotbanned dangotbanned Mar 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before anyone asks - yes this is a fully working implementation 😄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow!

21 changes: 0 additions & 21 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pyarrow as pa
import pyarrow.compute as pc

from narwhals.exceptions import ShapeError
from narwhals.utils import _SeriesNamespace
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
Expand Down Expand Up @@ -280,26 +279,6 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]:
return reshaped


def extract_dataframe_comparand(
    length: int,
    other: ArrowSeries,
    backend_version: tuple[int, ...],
) -> ArrowChunkedArray:
    """Return the native data of `other`, broadcast to `length` when flagged.

    When `other._broadcast` is falsy, `other` must already contain `length`
    elements and `other.native` is returned unchanged. Otherwise the first
    element of `other.native` is repeated `length` times with `numpy.full`
    and wrapped in a single-chunk `pa.chunked_array`.

    Raises:
        ShapeError: If `other` is not flagged for broadcasting and its length
            differs from `length`.
    """
    if other._broadcast:
        import numpy as np  # ignore-banned-import

        scalar = other.native[0]
        # For backend versions below 13 the element is converted with
        # `as_py()` (when it has that method) before filling the array.
        unwrap = backend_version < (13,) and hasattr(scalar, "as_py")
        fill = scalar.as_py() if unwrap else scalar
        return pa.chunked_array([np.full(shape=length, fill_value=fill)])

    actual = len(other)
    if actual != length:
        msg = f"Expected object of length {length}, got: {actual}."
        raise ShapeError(msg)
    return other.native


def horizontal_concat(dfs: list[pa.Table]) -> pa.Table:
"""Concatenate (native) DataFrames horizontally.

Expand Down
8 changes: 8 additions & 0 deletions narwhals/_compliant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from narwhals._compliant.typing import IntoCompliantExpr
from narwhals._compliant.typing import NativeFrameT_co
from narwhals._compliant.typing import NativeSeriesT_co
from narwhals._compliant.when_then import CompliantThen
from narwhals._compliant.when_then import CompliantWhen
from narwhals._compliant.when_then import EagerWhen
from narwhals._compliant.when_then import LazyWhen

__all__ = [
"CompliantDataFrame",
Expand All @@ -43,6 +47,8 @@
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"CompliantThen",
"CompliantWhen",
"DepthTrackingGroupBy",
"EagerDataFrame",
"EagerDataFrameT",
Expand All @@ -52,12 +58,14 @@
"EagerSelectorNamespace",
"EagerSeries",
"EagerSeriesT",
"EagerWhen",
"EvalNames",
"EvalSeries",
"IntoCompliantExpr",
"LazyExpr",
"LazyGroupBy",
"LazySelectorNamespace",
"LazyWhen",
"NativeFrameT_co",
"NativeSeriesT_co",
]
8 changes: 8 additions & 0 deletions narwhals/_compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def unpivot(
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str) -> Self: ...
def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
    """Evaluate `expr` against this frame and return its single output.

    `expr` is called with this frame as its only argument, and the first
    element of what it returns is handed back to the caller.
    """
    outputs = expr(self)
    n_outputs = len(outputs)
    # Debug-only sanity check: exactly one output is expected here.
    assert n_outputs == 1  # debug assertion # noqa: S101
    return outputs[0]


class EagerDataFrame(
Expand Down Expand Up @@ -300,3 +304,7 @@ def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSerie
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
raise AssertionError(msg)
return result

def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
    """Extract native Series, broadcasting to `len(self)` if necessary.

    This declaration carries no logic of its own and evaluates to `None`
    if called directly.
    """
4 changes: 4 additions & 0 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ class LazyExpr(
replace_strict: not_implemented = not_implemented()
cat: not_implemented = not_implemented() # pyright: ignore[reportAssignmentType]

@classmethod
def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
    """Report whether `obj` exposes a `__narwhals_expr__` attribute.

    The check is purely attribute-based; `cls` itself is not consulted.
    """
    marker = "__narwhals_expr__"
    return hasattr(obj, marker)


class EagerExprNamespace(_ExprNamespace[EagerExprT], Generic[EagerExprT]):
def __init__(self, expr: EagerExprT, /) -> None:
Expand Down
19 changes: 15 additions & 4 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import EagerDataFrameT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT_co
from narwhals._compliant.typing import EagerSeriesT
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen
from narwhals._compliant.when_then import EagerWhen
from narwhals.dtypes import DType
from narwhals.utils import Implementation
from narwhals.utils import Version

Incomplete: TypeAlias = Any

__all__ = ["CompliantNamespace", "EagerNamespace"]


Expand Down Expand Up @@ -65,7 +71,9 @@ def concat(
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> CompliantFrameT: ...
def when(self, predicate: CompliantExprT) -> Any: ...
def when(
self, predicate: CompliantExprT
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
def concat_str(
self,
*exprs: CompliantExprT,
Expand All @@ -80,7 +88,10 @@ def _expr(self) -> type[CompliantExprT]: ...

class EagerNamespace(
CompliantNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT],
):
@property
def _series(self) -> type[EagerSeriesT_co]: ...
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, Incomplete]: ...
Loading
Loading