diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index d6e2f4cdaf..64eef0c4e0 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -73,12 +73,7 @@ Combination: TypeAlias = Union[ - ir.SortBy, - ir.BinaryExpr, - ir.TernaryExpr, - ir.Filter, - ir.OrderedWindowExpr, - ir.WindowExpr, + ir.SortBy, ir.BinaryExpr, ir.TernaryExpr, ir.Filter, ir.OverOrdered, ir.Over ] @@ -268,11 +263,11 @@ def _expand_only(self, origin: ExprIR, child: ExprIR, /) -> ExprIR: # TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]: changes: dict[str, Any] = {} - if isinstance(origin, (ir.WindowExpr, ir.Filter, ir.SortBy)): - if isinstance(origin, ir.WindowExpr): + if isinstance(origin, (ir.Over, ir.Filter, ir.SortBy)): + if isinstance(origin, ir.Over): if partition_by := origin.partition_by: changes["partition_by"] = tuple(self._expand_inner(partition_by)) - if isinstance(origin, ir.OrderedWindowExpr): + if isinstance(origin, ir.OverOrdered): changes["order_by"] = tuple(self._expand_inner(origin.order_by)) elif isinstance(origin, ir.SortBy): changes["by"] = tuple(self._expand_inner(origin.by)) @@ -355,7 +350,7 @@ def _expand_function_expr( ir.BinaryExpr, ir.TernaryExpr, ir.Filter, - ir.OrderedWindowExpr, - ir.WindowExpr, + ir.OverOrdered, + ir.Over, ) """more than one (direct) child and those can be nested.""" diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 80b6a0a8d4..b82ee1e6c5 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -155,14 +155,6 @@ def iter_right(self) -> Iterator[ExprIR]: for node in reversed(child): # pragma: no cover yield from node.iter_right() - def iter_root_names(self) -> Iterator[ExprIR]: - """Override for different iteration behavior in `ExprIR.meta.root_names`. - - Note: - Identical to `iter_left` by default. - """ - yield from self.iter_left() - def iter_output_name(self) -> Iterator[ExprIR]: """Override for different iteration behavior in `ExprIR.meta.output_name`. diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 30503c55c3..b06bfbaa33 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -113,8 +113,8 @@ def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSerie ) -def is_window_expr(obj: Any) -> TypeIs[ir.WindowExpr]: - return isinstance(obj, _ir().WindowExpr) +def is_over(obj: Any) -> TypeIs[ir.Over]: + return isinstance(obj, _ir().Over) def is_function_expr(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index 1a4d40fb7d..de08e8e86f 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -9,7 +9,7 @@ is_aggregation, is_binary_expr, is_function_expr, - is_window_expr, + is_over, ) from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.common import replace @@ -49,7 +49,7 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: [discord-0]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384807793512677398 """ if ( - is_window_expr(window) + is_over(window) and is_function_expr(window.expr) and window.expr.options.is_elementwise() ): @@ -75,7 +75,7 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: [discord-2]: https://discord.com/channels/1235257048170762310/1383078215303696544/1384869107203047588 """ if ( - is_window_expr(window) + is_over(window) and is_binary_expr(window.expr) and (is_aggregation(window.expr.right)) ): diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index 0a95b18a84..74f26de213 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -539,7 +539,7 @@ def skew(self, node: FExpr[F.Skew], frame: Frame, name: str) -> Scalar: def over( self, - node: ir.WindowExpr, + node: ir.Over, frame: Frame, name: str, *, @@ -558,7 +558,7 @@ def over( return self.from_series(results.get_column(name)) def over_ordered( - self, node: ir.OrderedWindowExpr, frame: Frame, name: str + self, node: ir.OverOrdered, frame: Frame, name: str ) -> Self | Scalar: by = node.order_by_names() indices = fn.sort_indices(frame.native, *by, options=node.sort_options) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 39a7a912a1..dbe2394395 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -97,11 +97,11 @@ def is_not_null( self, node: FunctionExpr[IsNotNull], frame: FrameT_contra, name: str ) -> Self: ... def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... - def over(self, node: ir.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... + def over(self, node: ir.Over, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` # e.g. `nw.col("a").first().over(order_by="b")` def over_ordered( - self, node: ir.OrderedWindowExpr, frame: FrameT_contra, name: str + self, node: ir.OverOrdered, frame: FrameT_contra, name: str ) -> Self | CompliantScalar[FrameT_contra, SeriesT_co]: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... def rolling_expr( diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 53aaa49f67..494088a644 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -127,7 +127,7 @@ def binary_expr_length_changing_error( # TODO @dangotbanned: Use arguments in error message def over_nested_error( - expr: ir.WindowExpr, # noqa: ARG001 + expr: ir.Over, # noqa: ARG001 partition_by: Seq[ir.ExprIR], # noqa: ARG001 order_by: Seq[ir.ExprIR] = (), # noqa: ARG001 sort_options: SortOptions | None = None, # noqa: ARG001 @@ -159,7 +159,7 @@ def over_row_separable_error( def over_order_by_names_error( - expr: ir.OrderedWindowExpr, by: ir.ExprIR + expr: ir.OverOrdered, by: ir.ExprIR ) -> InvalidOperationError: if by.meta.is_column_selection(allow_aliasing=True): # narwhals dev error diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 56e4f5d4ce..c5422f1931 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -32,7 +32,8 @@ InvertSelector, Len, Literal, - OrderedWindowExpr, + Over, + OverOrdered, RangeExpr, RollingExpr, RootSelector, @@ -40,7 +41,6 @@ SortBy, StructExpr, TernaryExpr, - WindowExpr, col, ternary_expr, ) @@ -65,7 +65,8 @@ "Literal", "NamedIR", "OrderableAggExpr", - "OrderedWindowExpr", + "Over", + "OverOrdered", "RangeExpr", "RenameAlias", "RollingExpr", @@ -75,7 +76,6 @@ "SortBy", "StructExpr", "TernaryExpr", - "WindowExpr", "aggregation", "boolean", "categorical", diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 87271ea1db..cbef1b12ed 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -39,7 +39,6 @@ from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expressions.functions import MapBatches # noqa: F401 from narwhals._plan.expressions.literal import LiteralValue - from narwhals._plan.expressions.window import Window from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType @@ -56,6 +55,7 @@ "FunctionExpr", "Len", "Literal", + "Over", "RollingExpr", "RootSelector", "SelectorIR", @@ -63,7 +63,6 @@ "SortBy", "StructExpr", "TernaryExpr", - "WindowExpr", "col", "ternary_expr", ] @@ -353,9 +352,7 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() -class WindowExpr( - ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") -): +class Over(ExprIR, child=("expr", "partition_by")): """A fully specified `.over()`, that occurred after another expression. Related: @@ -364,11 +361,10 @@ class WindowExpr( - https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/mod.rs#L840-L876 """ - __slots__ = ("expr", "partition_by", "options") # noqa: RUF023 + __slots__ = ("expr", "partition_by") expr: ExprIR """For lazy backends, this should be the only place we allow `rolling_*`, `cum_*`.""" partition_by: Seq[ExprIR] - options: Window def __repr__(self) -> str: return f"{self.expr!r}.over({list(self.partition_by)!r})" @@ -377,17 +373,12 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() -class OrderedWindowExpr( - WindowExpr, - child=("expr", "partition_by", "order_by"), - config=ExprIROptions.renamed("over_ordered"), -): +class OverOrdered(Over, child=("expr", "partition_by", "order_by")): __slots__ = ("order_by", "sort_options") expr: ExprIR partition_by: Seq[ExprIR] order_by: Seq[ExprIR] sort_options: SortOptions - options: Window def __repr__(self) -> str: order = self.order_by @@ -397,17 +388,6 @@ def __repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" - # TODO @dangotbanned: Update to align with https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 - def iter_root_names(self) -> t.Iterator[ExprIR]: - # NOTE: `order_by` ~~is~~ was never considered in `polars` - # To match that behavior for `root_names` - but still expand in all other cases - # - this little escape hatch exists - # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - yield self - def order_by_names(self) -> Iterator[str]: """Yield the names resolved from expanding `order_by`. diff --git a/narwhals/_plan/expressions/window.py b/narwhals/_plan/expressions/window.py index 772af084f0..2fb1f2a397 100644 --- a/narwhals/_plan/expressions/window.py +++ b/narwhals/_plan/expressions/window.py @@ -2,48 +2,42 @@ from typing import TYPE_CHECKING -from narwhals._plan._guards import is_function_expr, is_window_expr -from narwhals._plan._immutable import Immutable +from narwhals._plan._guards import is_function_expr, is_over from narwhals._plan.exceptions import ( over_elementwise_error as elementwise_error, over_nested_error as nested_error, over_row_separable_error as row_separable_error, ) -from narwhals._plan.expressions.expr import OrderedWindowExpr, WindowExpr +from narwhals._plan.expressions.expr import Over, OverOrdered from narwhals._plan.options import SortOptions if TYPE_CHECKING: from narwhals._plan.expressions import ExprIR from narwhals._plan.typing import Seq + from narwhals.exceptions import InvalidOperationError -class Window(Immutable): - """Renamed from `WindowType` https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/options/mod.rs#L139.""" - - -class Over(Window): - @staticmethod - def _validate_over( - expr: ExprIR, - partition_by: Seq[ExprIR], - order_by: Seq[ExprIR] = (), - sort_options: SortOptions | None = None, - /, - ) -> ValueError | None: - if is_window_expr(expr): - return nested_error(expr, partition_by, order_by, sort_options) - if is_function_expr(expr): - if expr.options.is_elementwise(): - return elementwise_error(expr, partition_by, order_by, sort_options) - if expr.options.is_row_separable(): - return row_separable_error(expr, partition_by, order_by, sort_options) - return None - - -def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> WindowExpr: - if err := Over._validate_over(expr, partition_by): +def _validate_over( + expr: ExprIR, + partition_by: Seq[ExprIR], + order_by: Seq[ExprIR] = (), + sort_options: SortOptions | None = None, + /, +) -> InvalidOperationError | None: + if is_over(expr): + return nested_error(expr, partition_by, order_by, sort_options) + if is_function_expr(expr): + if expr.options.is_elementwise(): + return elementwise_error(expr, partition_by, order_by, sort_options) + if expr.options.is_row_separable(): + return row_separable_error(expr, partition_by, order_by, sort_options) + return None + + +def over(expr: ExprIR, partition_by: Seq[ExprIR], /) -> Over: + if err := _validate_over(expr, partition_by): raise err - return WindowExpr(expr=expr, partition_by=partition_by, options=Over()) + return Over(expr=expr, partition_by=partition_by) def over_ordered( @@ -54,14 +48,10 @@ def over_ordered( *, descending: bool = False, nulls_last: bool = False, -) -> OrderedWindowExpr: +) -> OverOrdered: sort_options = SortOptions(descending=descending, nulls_last=nulls_last) - if err := Over._validate_over(expr, partition_by, order_by, sort_options): + if err := _validate_over(expr, partition_by, order_by, sort_options): raise err - return OrderedWindowExpr( - expr=expr, - partition_by=partition_by, - order_by=order_by, - sort_options=sort_options, - options=Over(), + return OverOrdered( + expr=expr, partition_by=partition_by, order_by=order_by, sort_options=sort_options ) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 5938795a13..bbfd2d2a28 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -89,7 +89,7 @@ def as_selector(self) -> Selector: def iter_root_names(expr: ir.ExprIR, /) -> Iterator[str]: - yield from (e.name for e in expr.iter_root_names() if isinstance(e, ir.Column)) + yield from (e.name for e in expr.iter_left() if isinstance(e, ir.Column)) def root_name_first(expr: ir.ExprIR, /) -> str: diff --git a/pyproject.toml b/pyproject.toml index 154b7616bc..44b1ac2ba3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ typing = [ # keep some of these pinned and bump periodically so there's fewer s "pyright", "pyarrow-stubs==19.2", "sqlframe", - "polars==1.34.0", + "polars>=1.36.0", "uv", "narwhals[ibis]", ] diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index b6e640adf8..c230dbcb2c 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -744,7 +744,7 @@ def test_expand_function_expr_multi_invalid(df_1: Frame) -> None: def test_over_order_by_names() -> None: expr = nwp.col("a").first().over(order_by=ncs.string()) e_ir = expr._ir - assert isinstance(e_ir, ir.OrderedWindowExpr) + assert isinstance(e_ir, ir.OverOrdered) with pytest.raises( InvalidOperationError, match=re_compile( diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 34ab20a1c0..3abc09fe9f 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -244,7 +244,7 @@ def test_over_invalid() -> None: # NOTE: This version isn't elementwise expr_ir = nwp.col("a").fill_null(strategy="backward").over("b")._ir - assert isinstance(expr_ir, ir.WindowExpr) + assert isinstance(expr_ir, ir.Over) assert isinstance(expr_ir.expr, ir.FunctionExpr) assert isinstance(expr_ir.expr.function, F.FillNullWithStrategy) diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index fdcdb566c5..70f3445e58 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -12,7 +12,6 @@ rewrite_binary_agg_over, rewrite_elementwise_over, ) -from narwhals._plan.expressions.window import Over from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_expr_ir_equal, named_ir @@ -38,11 +37,10 @@ def schema_2() -> dict[str, DType]: } -def _to_window_expr(into_expr: IntoExpr, *partition_by: IntoExpr) -> ir.WindowExpr: - return ir.WindowExpr( +def _over(into_expr: IntoExpr, *partition_by: IntoExpr) -> ir.Over: + return ir.Over( expr=_parse.parse_into_expr_ir(into_expr), partition_by=_parse.parse_into_seq_of_expr_ir(*partition_by), - options=Over(), ) @@ -55,7 +53,7 @@ def test_rewrite_elementwise_over_simple(schema_2: dict[str, DType]) -> None: # Later, that error might not be needed if we can do this rewrite. # If you're here because of a "Did not raise" - just replace everything with the (previously) erroring expr. expected = nwp.col("a").sum().over("b").abs() - before = _to_window_expr(nwp.col("a").sum().abs(), "b").to_narwhals() + before = _over(nwp.col("a").sum().abs(), "b").to_narwhals() assert_expr_ir_equal(before, "col('a').sum().abs().over([col('b')])") actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) assert len(actual) == 1 @@ -67,9 +65,7 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: nwp.col("b").last().over("d").replace_strict({1: 2}), nwp.col("c").last().over("d").replace_strict({1: 2}), ) - before = _to_window_expr( - nwp.col("b", "c").last().replace_strict({1: 2}), "d" - ).to_narwhals() + before = _over(nwp.col("b", "c").last().replace_strict({1: 2}), "d").to_narwhals() assert_expr_ir_equal( before, "ncs.by_name('b', 'c', require_all=True).last().replace_strict().over([col('d')])", @@ -97,14 +93,14 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: nwp.col("a"), nwp.col("b").cast(nw.String), ( - _to_window_expr(nwp.col("c").max().alias("x").fill_null(50), "a") + _over(nwp.col("c").max().alias("x").fill_null(50), "a") .to_narwhals() .alias("x2") ), ~(nwp.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), ncs.string().str.contains("some").name.suffix("_some"), ( - _to_window_expr(nwp.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") + _over(nwp.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") .to_narwhals() .name.to_uppercase() ), diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 5385c65e17..dbb48de5ac 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -19,23 +19,21 @@ if POLARS_VERSION >= (1, 0): # https://github.com/pola-rs/polars/pull/16743 if POLARS_VERSION >= (1, 36): # pragma: no cover - # TODO @dangotbanned: Update special-casing in `OrderedWindowExpr` - # https://github.com/pola-rs/polars/pull/25117/files#diff-45d1f22172e291bd4a5ce36d1fb8233698394f9590bcf11382b9c99b5449fff5 - marks: tuple[pytest.MarkDecorator, ...] = ( + marks: tuple[pytest.MarkDecorator, ...] = () + else: # pragma: no cover + marks = ( pytest.mark.xfail( reason=( - "`polars==1.36.0b1` now considers `order_by` in `root_names`\n" + "`polars>=1.36.0b1` now considers `order_by` in `root_names`\n" r"https://github.com/pola-rs/polars/pull/25117" ), raises=AssertionError, ), ) - else: # pragma: no cover - marks = () OVER_CASE = pytest.param( nwp.col("a").last().over("b", order_by="c"), pl.col("a").last().over("b", order_by="c"), - ["a", "b"], + ["a", "b", "c"], marks=marks, ) else: # pragma: no cover