Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f6deab5
feat(DRAFT): Adds `Compliant(When|Then)`
dangotbanned Mar 21, 2025
db5c25f
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 21, 2025
3486f55
chore(DRAFT): Try out more complete `CompliantThen`
dangotbanned Mar 21, 2025
f1689e7
fix(typing): make `SeriesT` invariant
dangotbanned Mar 21, 2025
de3ccad
fix(typing): Sneaky `_when` to avoid `_call` overloading
dangotbanned Mar 21, 2025
8d671b2
feat: Implement `Arrow(Then|When)`
dangotbanned Mar 21, 2025
1ac7daa
fix: Older `Protocol` compat?
dangotbanned Mar 21, 2025
9e7d848
chore: imports/exports
dangotbanned Mar 21, 2025
ebbefa8
feat: Implement `Dask(Then|When)`
dangotbanned Mar 21, 2025
6767c38
feat: Implement `Pandas(When|Then)`
dangotbanned Mar 21, 2025
396605e
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
5a22284
refactor: `CompliantWhen.__init__` -> `.from_expr`
dangotbanned Mar 22, 2025
c002fcf
refactor: Naming conventions
dangotbanned Mar 22, 2025
3b81fbe
refactor(typing): Clean up, expose bound aliases
dangotbanned Mar 22, 2025
f2149e6
Make annotations readable again
dangotbanned Mar 22, 2025
32d2741
typo
dangotbanned Mar 22, 2025
541f532
feat: Implement `DuckDB(When|Then)`
dangotbanned Mar 22, 2025
6183be7
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
724ff96
feat: Implement `SparkLike(When|Then)`
dangotbanned Mar 22, 2025
a9a6ccd
refactor: Add `EagerDataFrame._extract_comparand`
dangotbanned Mar 22, 2025
5adda67
feat: Implement `EagerWhen` :sunglasses:
dangotbanned Mar 22, 2025
9c9ea48
chore(typing): Partially annotate `.when`
dangotbanned Mar 22, 2025
cc5cce1
refactor: Initial `LazyWhen` idea
dangotbanned Mar 22, 2025
b0b1a05
refactor(DRAFT): `LazyWhen`
dangotbanned Mar 22, 2025
9c680dc
remove dead code
dangotbanned Mar 22, 2025
ededf10
refactor: remove from `spark_like` as well
dangotbanned Mar 22, 2025
2e38646
Merge remote-tracking branch 'upstream/main' into compliant-when
dangotbanned Mar 22, 2025
9c85a3f
Merge branch 'main' into compliant-when
dangotbanned Mar 22, 2025
ed2e7e8
Merge branch 'main' into compliant-when
MarcoGorelli Mar 23, 2025
07f51da
merge fixup
MarcoGorelli Mar 23, 2025
a339a5e
revert: redo merge conflict from (https://github.com/narwhals-dev/nar…
dangotbanned Mar 23, 2025
9632f3a
Merge branch 'compliant-when' of https://github.com/narwhals-dev/narw…
dangotbanned Mar 23, 2025
328bb43
Update imports/exports
dangotbanned Mar 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 10 additions & 66 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import Sequence
Expand All @@ -24,23 +23,20 @@
from narwhals._arrow.utils import nulls_like
from narwhals._arrow.utils import vertical_concat
from narwhals._compliant import EagerNamespace
from narwhals._compliant.when_then import CompliantThen
from narwhals._compliant.when_then import CompliantWhen
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
from typing import Callable

from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._arrow.typing import Incomplete
from narwhals.dtypes import DType
from narwhals.utils import Version

_Scalar: TypeAlias = Any


class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
@property
Expand Down Expand Up @@ -253,7 +249,7 @@ def selectors(self: Self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace(self)

def when(self: Self, predicate: ArrowExpr) -> ArrowWhen:
return ArrowWhen(predicate, self._backend_version, version=self._version)
return ArrowWhen(predicate, context=self)

def concat_str(
self: Self,
Expand Down Expand Up @@ -293,25 +289,14 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)


class ArrowWhen:
def __init__(
self: Self,
condition: ArrowExpr,
backend_version: tuple[int, ...],
then_value: ArrowExpr | _Scalar = None,
otherwise_value: ArrowExpr | _Scalar = None,
*,
version: Version,
) -> None:
self._backend_version = backend_version
self._condition: ArrowExpr = condition
self._then_value: ArrowExpr | _Scalar = then_value
self._otherwise_value: ArrowExpr | _Scalar = otherwise_value
self._version = version
class ArrowWhen(CompliantWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]):
@property
def _then(self) -> type[ArrowThen]:
return ArrowThen

def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
def __call__(self, df: ArrowDataFrame, /) -> Sequence[ArrowSeries]:
condition = self._condition(df)[0]
condition_native = condition._native_series
condition_native = condition.native

if isinstance(self._then_value, ArrowExpr):
value_series = self._then_value(df)[0]
Expand Down Expand Up @@ -346,46 +331,5 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
)
]

def then(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowThen:
self._then_value = value

return ArrowThen(
self,
depth=0,
function_name="whenthen",
evaluate_output_names=getattr(
value, "_evaluate_output_names", lambda _df: ["literal"]
),
alias_output_names=getattr(value, "_alias_output_names", None),
backend_version=self._backend_version,
version=self._version,
)


class ArrowThen(ArrowExpr):
def __init__(
self: Self,
call: ArrowWhen,
*,
depth: int,
function_name: str,
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
call_kwargs: dict[str, Any] | None = None,
implementation: Implementation | None = None,
) -> None:
self._backend_version = backend_version
self._version = version
self._call: ArrowWhen = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._call_kwargs = call_kwargs or {}

def otherwise(self: Self, value: ArrowExpr | ArrowSeries | _Scalar) -> ArrowExpr:
self._call._otherwise_value = value
self._function_name = "whenotherwise"
return self
class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ...

@dangotbanned dangotbanned Mar 21, 2025

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before anyone asks - yes this is a fully working implementation 😄

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow!

3 changes: 3 additions & 0 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound="CompliantSeries[Any]")
CompliantSeriesOrNativeExprT = TypeVar(
"CompliantSeriesOrNativeExprT", bound="CompliantSeries[Any] | NativeExpr"
)
CompliantSeriesOrNativeExprT_co = TypeVar(
"CompliantSeriesOrNativeExprT_co",
bound="CompliantSeries[Any] | NativeExpr",
Expand Down
113 changes: 113 additions & 0 deletions narwhals/_compliant/when_then.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence
from typing import cast

from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import CompliantExprT
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT

if TYPE_CHECKING:
from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import _FullContext

# Selects the base class exposed as `Protocol38`, used by the protocol classes in this module.
# At runtime: Python >= 3.9 gets the real `typing.Protocol`; older versions get `typing.Generic`.
# NOTE(review): presumably a workaround for generic `Protocol` limitations on 3.8 - confirm in the linked discussion.
if not TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 9):
from typing import Protocol as Protocol38
else:
from typing import Generic as Protocol38
else: # pragma: no cover
# Type checkers always see `Protocol`, whatever the runtime branch above picks.
# TODO @dangotbanned: Remove after dropping `3.8` (#2084)
# - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386
from typing import Protocol as Protocol38

_Scalar: TypeAlias = Any


class CompliantWhen(
Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT]
):
_condition: CompliantExprT
_then_value: CompliantExprT | CompliantSeriesOrNativeExprT | _Scalar
_otherwise_value: CompliantExprT | CompliantSeriesOrNativeExprT | _Scalar
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version

# The `CompliantThen` *class* (not an instance) paired with this `when`.
# `then()` calls `.from_when(...)` on the returned type to build the result.
# Declared with a `...` body only: no default is provided here.
@property
def _then(
self,
) -> type[
CompliantThen[CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT]
]: ...

# Makes a `when` object callable on a frame (positional-only argument), returning a
# sequence of series / native expressions per the annotation.
# `CompliantThen.from_when` assigns the `when` object itself to `_call`, so this
# signature is what the resulting expression invokes.
# Declared with a `...` body only.
# NOTE(review): implementations presumably evaluate `_condition`, `_then_value` and
# `_otherwise_value` against the frame - confirm against each backend.
def __call__(
self, compliant_frame: CompliantFrameT, /
) -> Sequence[CompliantSeriesOrNativeExprT]: ...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO

Try to find common parts and split up their implementation

  • Ideally be able to define more as private methods here
  • Less work for subclasses

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will allow most of (Arrow|Pandas)Then.__call__ to be the same

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't expecting I'd be able to reduce it down this much ...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved from (cc5cce1) onwards

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note

Mentioned in (b0b1a05) that I'm tentative on LazyWhen

Happy to merge as-is, but it's just the least-bad idea I've had so far


# Set the "then" branch and return the paired `CompliantThen` expression.
# `value` (positional-only) may be an expression, a series / native expression, or a scalar.
# Delegates to `self._then.from_when`, which stores `value` on `self._then_value`,
# so calling this mutates `self`.
def then(
self, value: CompliantExprT | CompliantSeriesOrNativeExprT | _Scalar, /
) -> CompliantThen[CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT]:
return self._then.from_when(self, value)

# Build a `when` from its predicate expression.
# `condition`: positional-only predicate expression, stored as `_condition`.
# `context`: keyword-only object whose `_implementation`, `_backend_version` and
# `_version` are copied onto this instance.
def __init__(self, condition: CompliantExprT, /, *, context: _FullContext) -> None:
self._condition = condition
# Both branches start as `None`; `then()` and `CompliantThen.otherwise()` assign them later.
self._then_value = None
self._otherwise_value = None
self._implementation = context._implementation
self._backend_version = context._backend_version
self._version = context._version
Comment thread
dangotbanned marked this conversation as resolved.
Outdated


class CompliantThen(
CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT],
Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT],
):
_call: Callable[[CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]]
_when: CompliantWhen[CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT]
Comment thread
dangotbanned marked this conversation as resolved.
Outdated
_function_name: str
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
_call_kwargs: dict[str, Any]

# Alternate constructor: build the `then` expression from a `when` object and its then-value.
# Both parameters are positional-only.
# Side effect: `then_value` is written onto `when._then_value`.
@classmethod
def from_when(
cls,
when: CompliantWhen[
CompliantFrameT, CompliantSeriesOrNativeExprT, CompliantExprT
],
then_value: CompliantExprT | CompliantSeriesOrNativeExprT | _Scalar,
/,
) -> Self:
when._then_value = then_value
# `__new__` bypasses `cls.__init__`; every attribute is assigned by hand below.
obj = cls.__new__(cls)
# The `when` object is itself the callable used to evaluate this expression,
# and is also kept as `_when` so `otherwise()` can set its otherwise-value.
obj._call = when
obj._when = when
obj._depth = 0
obj._function_name = "whenthen"
# Reuse the then-value's output-name hooks when it has them; a value without
# `_evaluate_output_names` falls back to the single name "literal".
obj._evaluate_output_names = getattr(
then_value, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then_value, "_alias_output_names", None)
# Backend/version context is copied from the `when` object.
obj._implementation = when._implementation
obj._backend_version = when._backend_version
obj._version = when._version
obj._call_kwargs = {}
return obj
Comment on lines +93 to +114

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important

Probably going to try something like this for _compliant.selectors.py

Will wait for (#2266) to land first though


# Set the "otherwise" branch and return this same object typed as a plain expression.
# `value` (positional-only) is stored on the underlying `when` as `_otherwise_value`,
# and `_function_name` is switched to "whenotherwise"; both are in-place mutations.
def otherwise(
self, value: CompliantExprT | CompliantSeriesOrNativeExprT | _Scalar, /
) -> CompliantExprT:
self._when._otherwise_value = value
self._function_name = "whenotherwise"
# `cast` only affects static typing; the returned object is `self`.
return cast("CompliantExprT", self)