Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions narwhals/_plan/_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,7 @@


Combination: TypeAlias = Union[
ir.SortBy,
ir.BinaryExpr,
ir.TernaryExpr,
ir.Filter,
ir.OrderedWindowExpr,
ir.WindowExpr,
ir.SortBy, ir.BinaryExpr, ir.TernaryExpr, ir.Filter, ir.OverOrdered, ir.Over
]


Expand Down Expand Up @@ -268,11 +263,11 @@ def _expand_only(self, origin: ExprIR, child: ExprIR, /) -> ExprIR:
# TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves
def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
changes: dict[str, Any] = {}
if isinstance(origin, (ir.WindowExpr, ir.Filter, ir.SortBy)):
if isinstance(origin, ir.WindowExpr):
if isinstance(origin, (ir.Over, ir.Filter, ir.SortBy)):
if isinstance(origin, ir.Over):
if partition_by := origin.partition_by:
changes["partition_by"] = tuple(self._expand_inner(partition_by))
if isinstance(origin, ir.OrderedWindowExpr):
if isinstance(origin, ir.OverOrdered):
changes["order_by"] = tuple(self._expand_inner(origin.order_by))
elif isinstance(origin, ir.SortBy):
changes["by"] = tuple(self._expand_inner(origin.by))
Expand Down Expand Up @@ -355,7 +350,7 @@ def _expand_function_expr(
ir.BinaryExpr,
ir.TernaryExpr,
ir.Filter,
ir.OrderedWindowExpr,
ir.WindowExpr,
ir.OverOrdered,
ir.Over,
)
"""more than one (direct) child and those can be nested."""
8 changes: 0 additions & 8 deletions narwhals/_plan/_expr_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,6 @@ def iter_right(self) -> Iterator[ExprIR]:
for node in reversed(child): # pragma: no cover
yield from node.iter_right()

def iter_root_names(self) -> Iterator[ExprIR]:
"""Override for different iteration behavior in `ExprIR.meta.root_names`.

Note:
Identical to `iter_left` by default.
"""
yield from self.iter_left()

def iter_output_name(self) -> Iterator[ExprIR]:
"""Override for different iteration behavior in `ExprIR.meta.output_name`.

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_plan/_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSerie
)


def is_window_expr(obj: Any) -> TypeIs[ir.WindowExpr]:
return isinstance(obj, _ir().WindowExpr)
def is_over(obj: Any) -> TypeIs[ir.Over]:
return isinstance(obj, _ir().Over)


def is_function_expr(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]:
Expand Down
6 changes: 3 additions & 3 deletions narwhals/_plan/_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
is_aggregation,
is_binary_expr,
is_function_expr,
is_window_expr,
is_over,
)
from narwhals._plan._parse import parse_into_seq_of_expr_ir
from narwhals._plan.common import replace
Expand Down Expand Up @@ -49,7 +49,7 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR:
[discord-0]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384807793512677398
"""
if (
is_window_expr(window)
is_over(window)
and is_function_expr(window.expr)
and window.expr.options.is_elementwise()
):
Expand All @@ -75,7 +75,7 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR:
[discord-2]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384869107203047588
"""
if (
is_window_expr(window)
is_over(window)
and is_binary_expr(window.expr)
and (is_aggregation(window.expr.right))
):
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def skew(self, node: FExpr[F.Skew], frame: Frame, name: str) -> Scalar:

def over(
self,
node: ir.WindowExpr,
node: ir.Over,
frame: Frame,
name: str,
*,
Expand All @@ -558,7 +558,7 @@ def over(
return self.from_series(results.get_column(name))

def over_ordered(
self, node: ir.OrderedWindowExpr, frame: Frame, name: str
self, node: ir.OverOrdered, frame: Frame, name: str
) -> Self | Scalar:
by = node.order_by_names()
indices = fn.sort_indices(frame.native, *by, options=node.sort_options)
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_plan/compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ def is_not_null(
self, node: FunctionExpr[IsNotNull], frame: FrameT_contra, name: str
) -> Self: ...
def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ...
def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ...
def over(self, node: ir.Over, frame: FrameT_contra, name: str) -> Self: ...
# NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr`
# e.g. `nw.col("a").first().over(order_by="b")`
def over_ordered(
self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str
self, node: ir.OverOrdered, frame: FrameT_contra, name: str
) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ...
def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ...
def rolling_expr(
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_plan/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def binary_expr_length_changing_error(

# TODO @dangotbanned: Use arguments in error message
def over_nested_error(
expr: ir.WindowExpr, # noqa: ARG001
expr: ir.Over, # noqa: ARG001
partition_by: Seq[ir.ExprIR], # noqa: ARG001
order_by: Seq[ir.ExprIR] = (), # noqa: ARG001
sort_options: SortOptions | None = None, # noqa: ARG001
Expand Down Expand Up @@ -159,7 +159,7 @@ def over_row_separable_error(


def over_order_by_names_error(
expr: ir.OrderedWindowExpr, by: ir.ExprIR
expr: ir.OverOrdered, by: ir.ExprIR
) -> InvalidOperationError:
if by.meta.is_column_selection(allow_aliasing=True):
# narwhals dev error
Expand Down
8 changes: 4 additions & 4 deletions narwhals/_plan/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
InvertSelector,
Len,
Literal,
OrderedWindowExpr,
Over,
OverOrdered,
RangeExpr,
RollingExpr,
RootSelector,
Sort,
SortBy,
StructExpr,
TernaryExpr,
WindowExpr,
col,
ternary_expr,
)
Expand All @@ -65,7 +65,8 @@
"Literal",
"NamedIR",
"OrderableAggExpr",
"OrderedWindowExpr",
"Over",
"OverOrdered",
"RangeExpr",
"RenameAlias",
"RollingExpr",
Expand All @@ -75,7 +76,6 @@
"SortBy",
"StructExpr",
"TernaryExpr",
"WindowExpr",
"aggregation",
"boolean",
"categorical",
Expand Down
28 changes: 4 additions & 24 deletions narwhals/_plan/expressions/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
from narwhals._plan.expressions.functions import MapBatches # noqa: F401
from narwhals._plan.expressions.literal import LiteralValue
from narwhals._plan.expressions.window import Window
from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions
from narwhals._plan.schema import FrozenSchema
from narwhals.dtypes import DType
Expand All @@ -56,14 +55,14 @@
"FunctionExpr",
"Len",
"Literal",
"Over",
"RollingExpr",
"RootSelector",
"SelectorIR",
"Sort",
"SortBy",
"StructExpr",
"TernaryExpr",
"WindowExpr",
"col",
"ternary_expr",
]
Expand Down Expand Up @@ -353,9 +352,7 @@ def iter_output_name(self) -> t.Iterator[ExprIR]:
yield from self.expr.iter_output_name()


class WindowExpr(
ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over")
):
class Over(ExprIR, child=("expr", "partition_by")):
"""A fully specified `.over()`, that occurred after another expression.

Related:
Expand All @@ -364,11 +361,10 @@ class WindowExpr(
- https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876
"""

__slots__ = ("expr", "partition_by", "options") # noqa: RUF023
__slots__ = ("expr", "partition_by")
expr: ExprIR
"""For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`."""
partition_by: Seq[ExprIR]
options: Window

def __repr__(self) -> str:
return f"{self.expr!r}.over({list(self.partition_by)!r})"
Expand All @@ -377,17 +373,12 @@ def iter_output_name(self) -> t.Iterator[ExprIR]:
yield from self.expr.iter_output_name()


class OrderedWindowExpr(
WindowExpr,
child=("expr", "partition_by", "order_by"),
config=ExprIROptions.renamed("over_ordered"),
):
class OverOrdered(Over, child=("expr", "partition_by", "order_by")):
__slots__ = ("order_by", "sort_options")
expr: ExprIR
partition_by: Seq[ExprIR]
order_by: Seq[ExprIR]
sort_options: SortOptions
options: Window

def __repr__(self) -> str:
order = self.order_by
Expand All @@ -397,17 +388,6 @@ def __repr__(self) -> str:
args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}"
return f"{self.expr!r}.over({args})"

# TODO @dangotbanned: Update to align with https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5
def iter_root_names(self) -> t.Iterator[ExprIR]:
# NOTE: `order_by` ~~is~~ was never considered in `polars`
# To match that behavior for `root_names` - but still expand in all other cases
# - this little escape hatch exists
# https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86
yield from self.expr.iter_left()
for e in self.partition_by:
yield from e.iter_left()
yield self

def order_by_names(self) -> Iterator[str]:
"""Yield the names resolved from expanding `order_by`.

Expand Down
64 changes: 27 additions & 37 deletions narwhals/_plan/expressions/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,42 @@

from typing import TYPE_CHECKING

from narwhals._plan._guards import is_function_expr, is_window_expr
from narwhals._plan._immutable import Immutable
from narwhals._plan._guards import is_function_expr, is_over
from narwhals._plan.exceptions import (
over_elementwise_error as elementwise_error,
over_nested_error as nested_error,
over_row_separable_error as row_separable_error,
)
from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr
from narwhals._plan.expressions.expr import Over, OverOrdered
from narwhals._plan.options import SortOptions

if TYPE_CHECKING:
from narwhals._plan.expressions import ExprIR
from narwhals._plan.typing import Seq
from narwhals.exceptions import InvalidOperationError


class Window(Immutable):
"""Renamed from `WindowType` https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139."""


class Over(Window):
@staticmethod
def _validate_over(
expr: ExprIR,
partition_by: Seq[ExprIR],
order_by: Seq[ExprIR] = (),
sort_options: SortOptions | None = None,
/,
) -> ValueError | None:
if is_window_expr(expr):
return nested_error(expr, partition_by, order_by, sort_options)
if is_function_expr(expr):
if expr.options.is_elementwise():
return elementwise_error(expr, partition_by, order_by, sort_options)
if expr.options.is_row_separable():
return row_separable_error(expr, partition_by, order_by, sort_options)
return None


def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr:
if err := Over._validate_over(expr, partition_by):
def _validate_over(
expr: ExprIR,
partition_by: Seq[ExprIR],
order_by: Seq[ExprIR] = (),
sort_options: SortOptions | None = None,
/,
) -> InvalidOperationError | None:
if is_over(expr):
return nested_error(expr, partition_by, order_by, sort_options)
if is_function_expr(expr):
if expr.options.is_elementwise():
return elementwise_error(expr, partition_by, order_by, sort_options)
if expr.options.is_row_separable():
return row_separable_error(expr, partition_by, order_by, sort_options)
return None


def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> Over:
if err := _validate_over(expr, partition_by):
raise err
return WindowExpr(expr=expr, partition_by=partition_by, options=Over())
return Over(expr=expr, partition_by=partition_by)


def over_ordered(
Expand All @@ -54,14 +48,10 @@ def over_ordered(
*,
descending: bool = False,
nulls_last: bool = False,
) -> OrderedWindowExpr:
) -> OverOrdered:
sort_options = SortOptions(descending=descending, nulls_last=nulls_last)
if err := Over._validate_over(expr, partition_by, order_by, sort_options):
if err := _validate_over(expr, partition_by, order_by, sort_options):
raise err
return OrderedWindowExpr(
expr=expr,
partition_by=partition_by,
order_by=order_by,
sort_options=sort_options,
options=Over(),
return OverOrdered(
expr=expr, partition_by=partition_by, order_by=order_by, sort_options=sort_options
)
2 changes: 1 addition & 1 deletion narwhals/_plan/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def as_selector(self) -> Selector:


def iter_root_names(expr: ir.ExprIR, /) -> Iterator[str]:
yield from (e.name for e in expr.iter_root_names() if isinstance(e, ir.Column))
yield from (e.name for e in expr.iter_left() if isinstance(e, ir.Column))


def root_name_first(expr: ir.ExprIR, /) -> str:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ typing = [ # keep some of these pinned and bump periodically so there's fewer s
"pyright",
"pyarrow-stubs==19.2",
"sqlframe",
"polars==1.34.0",
"polars>=1.36.0",
"uv",
"narwhals[ibis]",
]
Expand Down
2 changes: 1 addition & 1 deletion tests/plan/expr_expansion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def test_expand_function_expr_multi_invalid(df_1: Frame) -> None:
def test_over_order_by_names() -> None:
expr = nwp.col("a").first().over(order_by=ncs.string())
e_ir = expr._ir
assert isinstance(e_ir, ir.OrderedWindowExpr)
assert isinstance(e_ir, ir.OverOrdered)
with pytest.raises(
InvalidOperationError,
match=re_compile(
Expand Down
2 changes: 1 addition & 1 deletion tests/plan/expr_parsing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_over_invalid() -> None:

# NOTE: This version isn't elementwise
expr_ir = nwp.col("a").fill_null(strategy="backward").over("b")._ir
assert isinstance(expr_ir, ir.WindowExpr)
assert isinstance(expr_ir, ir.Over)
assert isinstance(expr_ir.expr, ir.FunctionExpr)
assert isinstance(expr_ir.expr.function, F.FillNullWithStrategy)

Expand Down
Loading
Loading