diff --git a/docs/how_it_works.md b/docs/how_it_works.md index 64e071fa88..2030a2db3c 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -272,7 +272,104 @@ print((pn.col("a") + 1).mean()) For simple aggregations, Narwhals can just look at `_depth` and `function_name` and figure out which (efficient) elementary operation this corresponds to in pandas. -## Broadcasting +## Expression Metadata + +Let's try printing out a few expressions to the console to see what they show us: + +```python exec="1" result="python" session="metadata" source="above" +import narwhals as nw + +print(nw.col("a")) +print(nw.col("a").mean()) +print(nw.col("a").mean().over("b")) +``` + +Note how they tell us something about their metadata. This section is all about +making sense of what that all means, what the rules are, and what it enables. + +### Expression kinds + +Each Narwhals expression can be of one of the following kinds: + +- `LITERAL`: expressions which correspond to literal values, such as the `3` in `nw.col('a')+3`. +- `AGGREGATION`: expressions which reduce a column to a single value (e.g. `nw.col('a').mean()`). +- `TRANSFORM`: expressions which don't change length (e.g. `nw.col('a').abs()`). +- `WINDOW`: like `TRANSFORM`, but the last operation is a (row-order-dependent) + window function (`rolling_*`, `cum_*`, `diff`, `shift`, `is_*_distinct`). +- `FILTRATION`: expressions which change length but don't + aggregate (e.g. `nw.col('a').drop_nulls()`). + +For example: + + - `nw.col('a')` is not order-dependent, so it's `TRANSFORM`. + - `nw.col('a').abs()` is not order-dependent, so it's a `TRANSFORM`. + - `nw.col('a').cum_sum()`'s last operation is `cum_sum`, so it's `WINDOW`. + - `nw.col('a').cum_sum() + 1`'s last operation is `__add__`, and it preserves + the input dataframe's length, so it's a `TRANSFORM`. + +How these change depends on the operation. + +#### Chaining + +Say we have `expr.expr_method()`. How does `expr`'s `ExprMetadata` change? 
+This depends on `expr_method`. + +- Element-wise expressions such as `abs`, `alias`, `cast`, `__invert__`, and + many more, preserve the input kind (unless `expr` is a `WINDOW`, in + which case it becomes a `TRANSFORM`. This is because for an expression + to be `WINDOW`, the last expression needs to be the order-dependent one). +- `rolling_*`, `cum_*`, `diff`, `shift`, `ewm_mean`, and `is_*_distinct` + are window functions and result in `WINDOW`. +- `mean`, `std`, `median`, and other aggregations result in `AGGREGATION`, + and can only be applied to `TRANSFORM` and `WINDOW`. +- `drop_nulls` and `filter` result in `FILTRATION`, and can only be applied + to `TRANSFORM` and `WINDOW`. +- `over` always results in `TRANSFORM`. This is a bit more complicated, + so we elaborate on it in the ["You open a window ..."](#you-open-a-window-to-another-window-to-another-window-to-another-window) section. + +#### Binary operations (e.g. `nw.col('a') + nw.col('b')`) + +How do expression kinds change under binary operations? For example, +if we do `expr1 + expr2`, then what can we say about the output kind? +The rules are: + +- If both are `LITERAL`, then the output is `LITERAL`. +- If one is a `FILTRATION`, then: + + - if the other is `LITERAL` or `AGGREGATION`, then the output is `FILTRATION`. + - else, we raise an error. + +- If one is `TRANSFORM` or `WINDOW` and the other is not `FILTRATION`, + then the output is `TRANSFORM`. +- If one is `AGGREGATION` and the other is `LITERAL` or `AGGREGATION`, + the output is `AGGREGATION`. + +For n-ary operations such as `nw.sum_horizontal`, the above logic is +extended across inputs. For example, `nw.sum_horizontal(expr1, expr2, expr3)` +is `LITERAL` if all of `expr1`, `expr2`, and `expr3` are. + +### "You open a window to another window to another window to another window" + +When we print out an expression, in addition to the expression kind, +we also see `window_kind`. 
There are four window kinds: + +- `NONE`: non-order-dependent operations, like `.abs()` or `.mean()`. +- `CLOSEABLE`: expression where the last operation is order-dependent. For + example, `nw.col('a').diff()`. +- `UNCLOSEABLE`: expression where some operation is order-dependent but + the order-dependent operation wasn't the last one. For example, + `nw.col('a').diff().abs()`. +- `CLOSED`: expression contains `over` at some point, and any order-dependent + operation was immediately followed by `over(order_by=...)`. + +When working with `DataFrame`s, row order is well-defined, as the dataframes +are assumed to be eager and in-memory. Therefore, expressions of every +window kind are allowed. + +When working with `LazyFrame`s, on the other hand, row order is undefined. +Therefore, window kinds must either be `NONE` or `CLOSED`. + +### Broadcasting When performing comparisons between columns and aggregations or scalars, we operate as if the aggregation or scalar was broadcasted to the length of the whole column. For example, if we @@ -282,14 +379,7 @@ with values `[-1, 0, 1]`. Different libraries do broadcasting differently. SQL-like libraries require an empty window function for expressions (e.g. `a - sum(a) over ()`), Polars does its own broadcasting of -length-1 Series, and pandas does its own broadcasting of scalars. Narwhals keeps track of -when to trigger a broadcast by tracking the `ExprKind` of each expression. `ExprKind` is an -`Enum` with four variants: - -- `TRANSFORM`: expressions which don't change length (e.g. `nw.col('a').abs()`). -- `AGGREGATION`: expressions which reduce a column to a single value (e.g. `nw.col('a').mean()`). -- `CHANGE_LENGTH`: expressions which change length but don't necessarily aggregate (e.g. `nw.col('a').drop_nulls()`). -- `LITERAL`: expressions which correspond to literal values, such as the `3` in `nw.col('a')+3`. +length-1 Series, and pandas does its own broadcasting of scalars. 
Narwhals triggers a broadcast in these situations: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index f82ef55839..168c125b06 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -121,7 +121,7 @@ class ExprKind(Enum): - LITERAL vs LITERAL -> LITERAL - FILTRATION vs (LITERAL | AGGREGATION) -> FILTRATION - FILTRATION vs (FILTRATION | TRANSFORM | WINDOW) -> raise - - (TRANSFORM | WINDOW) vs (LITERAL | AGGREGATION) -> TRANSFORM + - (TRANSFORM | WINDOW) vs (...) -> TRANSFORM - AGGREGATION vs (LITERAL | AGGREGATION) -> AGGREGATION """ @@ -191,14 +191,48 @@ def is_multi_output( return expansion_kind in {ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED} +class WindowKind(Enum): + """Describe what kind of window the expression contains.""" + + NONE = auto() + """e.g. `nw.col('a').abs()`, no windows.""" + + CLOSEABLE = auto() + """e.g. `nw.col('a').cum_sum()` - can be closed if immediately followed by `over(order_by=...)`.""" + + UNCLOSEABLE = auto() + """e.g. `nw.col('a').cum_sum().abs()` - the window function (`cum_sum`) wasn't immediately followed by + `over(order_by=...)`, and so the window is uncloseable. + + Uncloseable windows can be used freely in `nw.DataFrame`, but not in `nw.LazyFrame` where + row-order is undefined.""" + + CLOSED = auto() + """e.g. 
`nw.col('a').cum_sum().over(order_by='i')`.""" + + def is_open(self) -> bool: + return self in {WindowKind.UNCLOSEABLE, WindowKind.CLOSEABLE} + + def is_closed(self) -> bool: + return self is WindowKind.CLOSED + + def is_uncloseable(self) -> bool: + return self is WindowKind.UNCLOSEABLE + + class ExprMetadata: - __slots__ = ("_expansion_kind", "_kind", "_n_open_windows") + __slots__ = ("_expansion_kind", "_kind", "_window_kind") def __init__( - self, kind: ExprKind, /, *, n_open_windows: int, expansion_kind: ExpansionKind + self, + kind: ExprKind, + /, + *, + window_kind: WindowKind, + expansion_kind: ExpansionKind, ) -> None: self._kind: ExprKind = kind - self._n_open_windows = n_open_windows + self._window_kind = window_kind self._expansion_kind = expansion_kind def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover @@ -206,15 +240,15 @@ def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no c raise TypeError(msg) def __repr__(self) -> str: - return f"ExprMetadata(kind: {self._kind}, n_open_windows: {self._n_open_windows}, expansion_kind: {self._expansion_kind})" + return f"ExprMetadata(kind: {self._kind}, window_kind: {self._window_kind}, expansion_kind: {self._expansion_kind})" @property def kind(self) -> ExprKind: return self._kind @property - def n_open_windows(self) -> int: - return self._n_open_windows + def window_kind(self) -> WindowKind: + return self._window_kind @property def expansion_kind(self) -> ExpansionKind: @@ -223,50 +257,77 @@ def expansion_kind(self) -> ExpansionKind: def with_kind(self, kind: ExprKind, /) -> ExprMetadata: """Change metadata kind, leaving all other attributes the same.""" return ExprMetadata( - kind, n_open_windows=self._n_open_windows, expansion_kind=self._expansion_kind + kind, + window_kind=self._window_kind, + expansion_kind=self._expansion_kind, ) - def with_extra_open_window(self) -> ExprMetadata: - """Increment `n_open_windows` leaving other attributes the same.""" + 
def with_uncloseable_window(self) -> ExprMetadata: + """Add uncloseable window, leaving other attributes the same.""" + if self._window_kind is WindowKind.CLOSED: # pragma: no cover + msg = "Unreachable code, please report a bug." + raise AssertionError(msg) return ExprMetadata( self.kind, - n_open_windows=self._n_open_windows + 1, + window_kind=WindowKind.UNCLOSEABLE, + expansion_kind=self._expansion_kind, + ) + + def with_kind_and_closeable_window(self, kind: ExprKind, /) -> ExprMetadata: + """Change metadata kind and add closeable window. + + If we already have an uncloseable window, the window stays uncloseable. + """ + if self._window_kind is WindowKind.NONE: + window_kind = WindowKind.CLOSEABLE + elif self._window_kind is WindowKind.CLOSED: # pragma: no cover + msg = "Unreachable code, please report a bug." + raise AssertionError(msg) + else: + window_kind = WindowKind.UNCLOSEABLE + return ExprMetadata( + kind, + window_kind=window_kind, expansion_kind=self._expansion_kind, ) - def with_kind_and_extra_open_window(self, kind: ExprKind, /) -> ExprMetadata: - """Change metadata kind and increment `n_open_windows`.""" + def with_kind_and_uncloseable_window(self, kind: ExprKind, /) -> ExprMetadata: + """Change metadata kind and set window kind to uncloseable.""" return ExprMetadata( kind, - n_open_windows=self._n_open_windows + 1, + window_kind=WindowKind.UNCLOSEABLE, expansion_kind=self._expansion_kind, ) @staticmethod - def simple_selector() -> ExprMetadata: + def selector_single() -> ExprMetadata: # e.g. `nw.col('a')`, `nw.nth(0)` return ExprMetadata( - ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ExprKind.TRANSFORM, + window_kind=WindowKind.NONE, + expansion_kind=ExpansionKind.SINGLE, ) @staticmethod - def multi_output_selector_named() -> ExprMetadata: + def selector_multi_named() -> ExprMetadata: # e.g. 
`nw.col('a', 'b')` return ExprMetadata( - ExprKind.TRANSFORM, n_open_windows=0, expansion_kind=ExpansionKind.MULTI_NAMED + ExprKind.TRANSFORM, + window_kind=WindowKind.NONE, + expansion_kind=ExpansionKind.MULTI_NAMED, ) @staticmethod - def multi_output_selector_unnamed() -> ExprMetadata: + def selector_multi_unnamed() -> ExprMetadata: # e.g. `nw.all()` return ExprMetadata( ExprKind.TRANSFORM, - n_open_windows=0, + window_kind=WindowKind.NONE, expansion_kind=ExpansionKind.MULTI_UNNAMED, ) -def combine_metadata( +def combine_metadata( # noqa: PLR0915 *args: IntoExpr | object | None, str_as_lit: bool, allow_multi_output: bool, @@ -285,8 +346,10 @@ def combine_metadata( has_transforms_or_windows = False has_aggregations = False has_literals = False - result_n_open_windows = 0 result_expansion_kind = ExpansionKind.SINGLE + has_closeable_windows = False + has_uncloseable_windows = False + has_closed_windows = False for i, arg in enumerate(args): if isinstance(arg, str) and not str_as_lit: @@ -307,8 +370,6 @@ def combine_metadata( result_expansion_kind = resolve_expansion_kind( result_expansion_kind, arg._metadata.expansion_kind ) - if arg._metadata.n_open_windows: - result_n_open_windows += 1 kind = arg._metadata.kind if kind is ExprKind.AGGREGATION: has_aggregations = True @@ -322,6 +383,14 @@ def combine_metadata( msg = "unreachable code" raise AssertionError(msg) + window_kind = arg._metadata.window_kind + if window_kind is WindowKind.UNCLOSEABLE: + has_uncloseable_windows = True + elif window_kind is WindowKind.CLOSEABLE: + has_closeable_windows = True + elif window_kind is WindowKind.CLOSED: + has_closed_windows = True + if ( has_literals and not has_aggregations @@ -342,10 +411,15 @@ def combine_metadata( else: result_kind = ExprKind.AGGREGATION + if has_uncloseable_windows or has_closeable_windows: + result_window_kind = WindowKind.UNCLOSEABLE + elif has_closed_windows: + result_window_kind = WindowKind.CLOSED + else: + result_window_kind = WindowKind.NONE + 
return ExprMetadata( - result_kind, - n_open_windows=result_n_open_windows, - expansion_kind=result_expansion_kind, + result_kind, window_kind=result_window_kind, expansion_kind=result_expansion_kind ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index bd7a14edc8..7cd20abf61 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -2152,7 +2152,7 @@ def _extract_compliant(self: Self, arg: Any) -> Any: plx = self.__narwhals_namespace__() return plx.col(arg) if isinstance(arg, Expr): - if arg._metadata.n_open_windows > 0: + if arg._metadata._window_kind.is_open(): msg = ( "Order-dependent expressions are not supported for use in LazyFrame.\n\n" "Hints:\n" diff --git a/narwhals/expr.py b/narwhals/expr.py index c996a6dd67..49c80f316b 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -9,6 +9,7 @@ from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import WindowKind from narwhals._expression_parsing import apply_n_ary_operation from narwhals._expression_parsing import combine_metadata from narwhals._expression_parsing import combine_metadata_binary_op @@ -65,18 +66,52 @@ def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: self._metadata = metadata def _with_callable(self, to_compliant_expr: Callable[[Any], Any]) -> Self: - # Instantiate new Expr keeping metadata unchanged. + # Instantiate new Expr keeping metadata unchanged, unless + # it's a WINDOW, in which case make it a TRANSFORM. + if self._metadata.kind.is_window(): + # We had a window function, but it wasn't immediately followed by + # `over(order_by=...)` - it missed its chance, it's now forever uncloseable. 
+ return self.__class__( + to_compliant_expr, + self._metadata.with_kind_and_uncloseable_window(ExprKind.TRANSFORM), + ) return self.__class__(to_compliant_expr, self._metadata) + def _with_aggregation(self, to_compliant_expr: Callable[[Any], Any]) -> Self: + if self._metadata.kind.is_scalar_like(): + msg = "Aggregations can't be applied to scalar-like expressions." + raise InvalidOperationError(msg) + return self.__class__( + to_compliant_expr, self._metadata.with_kind(ExprKind.AGGREGATION) + ) + + def _with_order_dependent_aggregation( + self, to_compliant_expr: Callable[[Any], Any] + ) -> Self: + if self._metadata.kind.is_scalar_like(): + msg = "Aggregations can't be applied to scalar-like expressions." + raise InvalidOperationError(msg) + return self.__class__( + to_compliant_expr, + self._metadata.with_kind_and_closeable_window(ExprKind.AGGREGATION), + ) + + def _with_filtration(self, to_compliant_expr: Callable[[Any], Any]) -> Self: + if self._metadata.kind.is_scalar_like(): + msg = "Length-changing can't be applied to scalar-like expressions." + raise InvalidOperationError(msg) + return self.__class__( + to_compliant_expr, self._metadata.with_kind(ExprKind.FILTRATION) + ) + def __repr__(self: Self) -> str: return f"Narwhals Expr\nmetadata: {self._metadata}\n" def _taxicab_norm(self: Self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. 
- return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs().sum(), - self._metadata.with_kind(ExprKind.AGGREGATION), + return self._with_aggregation( + lambda plx: self._to_compliant_expr(plx).abs().sum() ) # --- convert --- @@ -382,10 +417,7 @@ def any(self: Self) -> Self: | 0 True True | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).any(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).any()) def all(self: Self) -> Self: """Return whether all values in the column are `True`. @@ -406,10 +438,7 @@ def all(self: Self) -> Self: | 0 False True | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).all(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).all()) def ewm_mean( self: Self, @@ -532,10 +561,7 @@ def mean(self: Self) -> Self: | 0 0.0 4.0 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).mean(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).mean()) def median(self: Self) -> Self: """Get median value. @@ -559,10 +585,7 @@ def median(self: Self) -> Self: | 0 3.0 4.0 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).median(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).median()) def std(self: Self, *, ddof: int = 1) -> Self: """Get standard deviation. 
@@ -587,9 +610,8 @@ def std(self: Self, *, ddof: int = 1) -> Self: |0 17.79513 1.265789| └─────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), - self._metadata.with_kind(ExprKind.AGGREGATION), + return self._with_aggregation( + lambda plx: self._to_compliant_expr(plx).std(ddof=ddof) ) def var(self: Self, *, ddof: int = 1) -> Self: @@ -615,9 +637,8 @@ def var(self: Self, *, ddof: int = 1) -> Self: |0 316.666667 1.602222| └───────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), - self._metadata.with_kind(ExprKind.AGGREGATION), + return self._with_aggregation( + lambda plx: self._to_compliant_expr(plx).var(ddof=ddof) ) def map_batches( @@ -664,7 +685,7 @@ def map_batches( function=function, return_dtype=return_dtype ), # safest assumptions - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + self._metadata.with_kind_and_closeable_window(ExprKind.FILTRATION), ) def skew(self: Self) -> Self: @@ -686,10 +707,7 @@ def skew(self: Self) -> Self: | 0 0.0 1.472427 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).skew(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).skew()) def sum(self: Self) -> Expr: """Return the sum value. @@ -714,10 +732,7 @@ def sum(self: Self) -> Expr: |└────────┴────────┘| └───────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).sum(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).sum()) def min(self: Self) -> Self: """Returns the minimum value(s) from a column(s). 
@@ -738,10 +753,7 @@ def min(self: Self) -> Self: | 0 1 3 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).min(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).min()) def max(self: Self) -> Self: """Returns the maximum value(s) from a column(s). @@ -762,10 +774,7 @@ def max(self: Self) -> Self: | 0 20 100 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).max(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).max()) def arg_min(self: Self) -> Self: """Returns the index of the minimum value. @@ -786,9 +795,8 @@ def arg_min(self: Self) -> Self: |0 0 1| └───────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_min(), - self._metadata.with_kind_and_extra_open_window(ExprKind.AGGREGATION), + return self._with_order_dependent_aggregation( + lambda plx: self._to_compliant_expr(plx).arg_min() ) def arg_max(self: Self) -> Self: @@ -810,9 +818,8 @@ def arg_max(self: Self) -> Self: |0 1 0| └───────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_max(), - self._metadata.with_kind_and_extra_open_window(ExprKind.AGGREGATION), + return self._with_order_dependent_aggregation( + lambda plx: self._to_compliant_expr(plx).arg_max() ) def count(self: Self) -> Self: @@ -834,10 +841,7 @@ def count(self: Self) -> Self: | 0 3 2 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).count(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).count()) def n_unique(self: Self) -> Self: """Returns count of unique values. 
@@ -858,10 +862,7 @@ def n_unique(self: Self) -> Self: | 0 5 3 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).n_unique(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).n_unique()) def unique(self: Self) -> Self: """Return unique values of this expression. @@ -882,10 +883,7 @@ def unique(self: Self) -> Self: | 0 9 12 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).unique(), - self._metadata.with_kind(ExprKind.FILTRATION), - ) + return self._with_filtration(lambda plx: self._to_compliant_expr(plx).unique()) def abs(self: Self) -> Self: """Return absolute value of each element. @@ -941,7 +939,7 @@ def cum_sum(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def diff(self: Self) -> Self: @@ -988,7 +986,7 @@ def diff(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).diff(), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def shift(self: Self, n: int) -> Self: @@ -1038,7 +1036,7 @@ def shift(self: Self, n: int) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).shift(n), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def replace_strict( @@ -1128,7 +1126,7 @@ def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> S lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - self._metadata.with_extra_open_window(), + self._metadata.with_uncloseable_window(), ) # --- transform --- @@ 
-1358,10 +1356,7 @@ def arg_true(self: Self) -> Self: "See https://narwhals-dev.github.io/narwhals/backcompat/ for more information.\n" ) issue_deprecation_warning(msg, _version="1.23.0") - return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_true(), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), - ) + return self._with_filtration(lambda plx: self._to_compliant_expr(plx).arg_true()) def fill_null( self: Self, @@ -1491,9 +1486,8 @@ def drop_nulls(self: Self) -> Self: | └─────┘ | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).drop_nulls(), - self._metadata.with_kind(ExprKind.FILTRATION), + return self._with_filtration( + lambda plx: self._to_compliant_expr(plx).drop_nulls() ) def sample( @@ -1530,11 +1524,10 @@ def sample( "See https://narwhals-dev.github.io/narwhals/backcompat/ for more information.\n" ) issue_deprecation_warning(msg, _version="1.23.0") - return self.__class__( + return self._with_filtration( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed - ), - self._metadata.with_kind(ExprKind.FILTRATION), + ) ) def over( @@ -1594,16 +1587,25 @@ def over( raise ValueError(msg) kind = ExprKind.TRANSFORM - n_open_windows = self._metadata.n_open_windows + window_kind = self._metadata.window_kind + if window_kind.is_closed(): + msg = "Nested `over` statements are not allowed." + raise InvalidOperationError(msg) if flat_order_by is not None and self._metadata.kind.is_window(): - n_open_windows -= 1 - elif flat_order_by is not None and not n_open_windows: + # debug assertion, an open window should already have been set + # by the window function. If it's immediately followed by `over`, then the + # window gets closed. + assert window_kind.is_open() # noqa: S101 + elif flat_order_by is not None and not window_kind.is_open(): msg = "Cannot use `order_by` in `over` on expression which isn't order-dependent." 
raise InvalidOperationError(msg) current_meta = self._metadata + next_window_kind = ( + WindowKind.UNCLOSEABLE if window_kind.is_uncloseable() else WindowKind.CLOSED + ) next_meta = ExprMetadata( kind, - n_open_windows=n_open_windows, + window_kind=next_window_kind, expansion_kind=current_meta.expansion_kind, ) @@ -1688,9 +1690,8 @@ def null_count(self: Self) -> Self: | 0 1 2 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).null_count(), - self._metadata.with_kind(ExprKind.AGGREGATION), + return self._with_aggregation( + lambda plx: self._to_compliant_expr(plx).null_count() ) def is_first_distinct(self: Self) -> Self: @@ -1723,7 +1724,7 @@ def is_first_distinct(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_first_distinct(), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def is_last_distinct(self: Self) -> Self: @@ -1756,7 +1757,7 @@ def is_last_distinct(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_last_distinct(), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def quantile( @@ -1793,9 +1794,8 @@ def quantile( | 0 24.5 74.5 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation), - self._metadata.with_kind(ExprKind.AGGREGATION), + return self._with_aggregation( + lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation) ) def head(self: Self, n: int = 10) -> Self: @@ -1821,10 +1821,7 @@ def head(self: Self, n: int = 10) -> Self: "See https://narwhals-dev.github.io/narwhals/backcompat/ for more information.\n" ) issue_deprecation_warning(msg, _version="1.23.0") - return self.__class__( - lambda plx: self._to_compliant_expr(plx).head(n), - 
self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), - ) + return self._with_filtration(lambda plx: self._to_compliant_expr(plx).head(n)) def tail(self: Self, n: int = 10) -> Self: r"""Get the last `n` rows. @@ -1849,10 +1846,7 @@ def tail(self: Self, n: int = 10) -> Self: "See https://narwhals-dev.github.io/narwhals/backcompat/ for more information.\n" ) issue_deprecation_warning(msg, _version="1.23.0") - return self.__class__( - lambda plx: self._to_compliant_expr(plx).tail(n), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), - ) + return self._with_filtration(lambda plx: self._to_compliant_expr(plx).tail(n)) def round(self: Self, decimals: int = 0) -> Self: r"""Round underlying floating point data by `decimals` digits. @@ -1915,10 +1909,7 @@ def len(self: Self) -> Self: | 0 2 1 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).len(), - self._metadata.with_kind(ExprKind.AGGREGATION), - ) + return self._with_aggregation(lambda plx: self._to_compliant_expr(plx).len()) def gather_every(self: Self, n: int, offset: int = 0) -> Self: r"""Take every nth value in the Series and return as new Series. 
@@ -1944,9 +1935,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: "See https://narwhals-dev.github.io/narwhals/backcompat/ for more information.\n" ) issue_deprecation_warning(msg, _version="1.23.0") - return self.__class__( - lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + return self._with_filtration( + lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset) ) # need to allow numeric typing @@ -2023,10 +2013,7 @@ def mode(self: Self) -> Self: | 0 1 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).mode(), - self._metadata.with_kind(ExprKind.FILTRATION), - ) + return self._with_filtration(lambda plx: self._to_compliant_expr(plx).mode()) def is_finite(self: Self) -> Self: """Returns boolean values indicating which original values are finite. @@ -2100,7 +2087,7 @@ def cum_count(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def cum_min(self: Self, *, reverse: bool = False) -> Self: @@ -2137,7 +2124,7 @@ def cum_min(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def cum_max(self: Self, *, reverse: bool = False) -> Self: @@ -2174,7 +2161,7 @@ def cum_max(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def 
cum_prod(self: Self, *, reverse: bool = False) -> Self: @@ -2211,7 +2198,7 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def rolling_sum( @@ -2277,7 +2264,7 @@ def rolling_sum( min_samples=min_samples_int, center=center, ), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def rolling_mean( @@ -2343,7 +2330,7 @@ def rolling_mean( min_samples=min_samples, center=center, ), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def rolling_var( @@ -2409,7 +2396,7 @@ def rolling_var( lambda plx: self._to_compliant_expr(plx).rolling_var( window_size=window_size, min_samples=min_samples, center=center, ddof=ddof ), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def rolling_std( @@ -2478,7 +2465,7 @@ def rolling_std( center=center, ddof=ddof, ), - self._metadata.with_kind_and_extra_open_window(ExprKind.WINDOW), + self._metadata.with_kind_and_closeable_window(ExprKind.WINDOW), ) def rank( diff --git a/narwhals/functions.py b/narwhals/functions.py index bd902716d3..8792844eba 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -15,6 +15,7 @@ from narwhals._expression_parsing import ExpansionKind from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import WindowKind from narwhals._expression_parsing import apply_n_ary_operation from narwhals._expression_parsing import check_expressions_preserve_length from narwhals._expression_parsing import combine_metadata @@ -1038,9 +1039,9 @@ def 
func(plx: Any) -> Any: return Expr( func, - ExprMetadata.simple_selector() + ExprMetadata.selector_single() if len(flat_names) == 1 - else ExprMetadata.multi_output_selector_named(), + else ExprMetadata.selector_multi_named(), ) @@ -1078,7 +1079,7 @@ def exclude(*names: str | Iterable[str]) -> Expr: def func(plx: Any) -> Any: return plx.exclude(exclude_names) - return Expr(func, ExprMetadata.multi_output_selector_unnamed()) + return Expr(func, ExprMetadata.selector_multi_unnamed()) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1118,9 +1119,9 @@ def func(plx: Any) -> Any: return Expr( func, - ExprMetadata.simple_selector() + ExprMetadata.selector_single() if len(flat_indices) == 1 - else ExprMetadata.multi_output_selector_unnamed(), + else ExprMetadata.selector_multi_unnamed(), ) @@ -1145,7 +1146,7 @@ def all_() -> Expr: | 1 4 0.246 | └──────────────────┘ """ - return Expr(lambda plx: plx.all(), ExprMetadata.multi_output_selector_unnamed()) + return Expr(lambda plx: plx.all(), ExprMetadata.selector_multi_unnamed()) # Add underscore so it doesn't conflict with builtin `len` @@ -1181,7 +1182,9 @@ def func(plx: Any) -> Any: return Expr( func, ExprMetadata( - ExprKind.AGGREGATION, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ExprKind.AGGREGATION, + window_kind=WindowKind.NONE, + expansion_kind=ExpansionKind.SINGLE, ), ) @@ -1653,7 +1656,9 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: return Expr( lambda plx: plx.lit(value, dtype), ExprMetadata( - ExprKind.LITERAL, n_open_windows=0, expansion_kind=ExpansionKind.SINGLE + ExprKind.LITERAL, + window_kind=WindowKind.NONE, + expansion_kind=ExpansionKind.SINGLE, ), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index d8603fe3c7..020e58fdbe 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -96,7 +96,7 @@ def by_dtype(*dtypes: DType | type[DType] | Iterable[DType | type[DType]]) -> Se flattened = flatten(dtypes) return Selector( lambda plx: 
plx.selectors.by_dtype(flattened), - ExprMetadata.multi_output_selector_unnamed(), + ExprMetadata.selector_multi_unnamed(), ) @@ -131,7 +131,7 @@ def matches(pattern: str) -> Selector: """ return Selector( lambda plx: plx.selectors.matches(pattern), - ExprMetadata.multi_output_selector_unnamed(), + ExprMetadata.selector_multi_unnamed(), ) @@ -162,7 +162,7 @@ def numeric() -> Selector: └─────┴─────┘ """ return Selector( - lambda plx: plx.selectors.numeric(), ExprMetadata.multi_output_selector_unnamed() + lambda plx: plx.selectors.numeric(), ExprMetadata.selector_multi_unnamed() ) @@ -197,7 +197,7 @@ def boolean() -> Selector: └──────────────────┘ """ return Selector( - lambda plx: plx.selectors.boolean(), ExprMetadata.multi_output_selector_unnamed() + lambda plx: plx.selectors.boolean(), ExprMetadata.selector_multi_unnamed() ) @@ -228,7 +228,7 @@ def string() -> Selector: └─────┘ """ return Selector( - lambda plx: plx.selectors.string(), ExprMetadata.multi_output_selector_unnamed() + lambda plx: plx.selectors.string(), ExprMetadata.selector_multi_unnamed() ) @@ -262,7 +262,7 @@ def categorical() -> Selector: """ return Selector( lambda plx: plx.selectors.categorical(), - ExprMetadata.multi_output_selector_unnamed(), + ExprMetadata.selector_multi_unnamed(), ) @@ -287,7 +287,7 @@ def all() -> Selector: 1 2 y True """ return Selector( - lambda plx: plx.selectors.all(), ExprMetadata.multi_output_selector_unnamed() + lambda plx: plx.selectors.all(), ExprMetadata.selector_multi_unnamed() ) @@ -348,7 +348,7 @@ def datetime( """ return Selector( lambda plx: plx.selectors.datetime(time_unit=time_unit, time_zone=time_zone), - ExprMetadata.multi_output_selector_unnamed(), + ExprMetadata.selector_multi_unnamed(), ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8dfb4838e5..01d3654630 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -997,7 +997,7 @@ def head(self: Self, n: int = 10) -> Self: """ return 
self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + self._metadata.with_kind_and_closeable_window(ExprKind.FILTRATION), ) def tail(self: Self, n: int = 10) -> Self: @@ -1011,7 +1011,7 @@ def tail(self: Self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + self._metadata.with_kind_and_closeable_window(ExprKind.FILTRATION), ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -1026,7 +1026,7 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + self._metadata.with_kind_and_closeable_window(ExprKind.FILTRATION), ) def unique(self: Self, *, maintain_order: bool | None = None) -> Self: @@ -1065,7 +1065,7 @@ def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> S lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - self._metadata.with_extra_open_window(), + self._metadata.with_uncloseable_window(), ) def arg_true(self: Self) -> Self: @@ -1076,7 +1076,7 @@ def arg_true(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - self._metadata.with_kind_and_extra_open_window(ExprKind.FILTRATION), + self._metadata.with_kind_and_closeable_window(ExprKind.FILTRATION), ) def sample( diff --git a/tests/expression_parsing_test.py b/tests/expression_parsing_test.py index 0f7e30674d..3644fbf7f7 100644 --- a/tests/expression_parsing_test.py +++ b/tests/expression_parsing_test.py @@ -3,25 +3,52 @@ import pytest import narwhals as nw +from narwhals._expression_parsing import WindowKind from narwhals.exceptions import InvalidOperationError @pytest.mark.parametrize( 
("expr", "expected"), [ - (nw.col("a"), 0), - (nw.col("a").mean(), 0), - (nw.col("a").cum_sum(), 1), - (nw.col("a").cum_sum().over(order_by="id"), 0), - ((nw.col("a").cum_sum() + 1).over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1), - (nw.col("a").cum_sum().cum_sum(), 2), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1), - (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over("a"), 1), + (nw.col("a"), WindowKind.NONE), + (nw.col("a").mean(), WindowKind.NONE), + (nw.col("a").cum_sum(), WindowKind.CLOSEABLE), + (nw.col("a").cum_sum().over(order_by="id"), WindowKind.CLOSED), + (nw.col("a").cum_sum().abs().over(order_by="id"), WindowKind.UNCLOSEABLE), + ((nw.col("a").cum_sum() + 1).over(order_by="id"), WindowKind.UNCLOSEABLE), + (nw.col("a").cum_sum().cum_sum().over(order_by="id"), WindowKind.UNCLOSEABLE), + (nw.col("a").cum_sum().cum_sum(), WindowKind.UNCLOSEABLE), + (nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), WindowKind.UNCLOSEABLE), + ( + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over("a"), + WindowKind.UNCLOSEABLE, + ), + ( + nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), + WindowKind.CLOSED, + ), + ( + nw.sum_horizontal( + nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i") + ), + WindowKind.UNCLOSEABLE, + ), + ( + nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( + order_by="i" + ), + WindowKind.UNCLOSEABLE, + ), + ( + nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( + order_by="i" + ), + WindowKind.UNCLOSEABLE, + ), ], ) -def test_has_open_windows(expr: nw.Expr, expected: int) -> None: - assert expr._metadata.n_open_windows == expected +def test_window_kind(expr: nw.Expr, expected: WindowKind) -> None: + assert expr._metadata.window_kind is expected def test_misleading_order_by() -> None: @@ -29,3 +56,20 @@ def test_misleading_order_by() -> None: nw.col("a").mean().over(order_by="b") with 
pytest.raises(InvalidOperationError): nw.col("a").rank().over(order_by="b") + + +def test_double_over() -> None: + with pytest.raises(InvalidOperationError): + nw.col("a").mean().over("b").over("c") + + +def test_double_agg() -> None: + with pytest.raises(InvalidOperationError): + nw.col("a").mean().mean() + with pytest.raises(InvalidOperationError): + nw.col("a").mean().arg_max() + + +def test_filter_aggregation() -> None: + with pytest.raises(InvalidOperationError): + nw.col("a").mean().drop_nulls() diff --git a/tests/group_by_test.py b/tests/group_by_test.py index d73a77dd17..27dde67ae0 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -55,7 +55,7 @@ def test_invalid_group_by_dask() -> None: df_dask = dd.from_pandas(df_pandas) with pytest.raises(ValueError, match=r"Non-trivial complex aggregation found"): - nw.from_native(df_dask).group_by("a").agg(nw.col("b").mean().min()) + nw.from_native(df_dask).group_by("a").agg(nw.col("b").abs().min()) def test_group_by_iter(constructor_eager: ConstructorEager) -> None: diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 30def36601..6f97d26b3c 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -162,7 +162,7 @@ # Expr methods expr_methods = [ i - for i in nw.Expr(lambda: 0, ExprMetadata.simple_selector()).__dir__() + for i in nw.Expr(lambda: 0, ExprMetadata.selector_single()).__dir__() if not i[0].isupper() and i[0] != "_" ] with open("docs/api-reference/expr.md") as fd: @@ -186,7 +186,7 @@ expr_methods = [ i for i in getattr( - nw.Expr(lambda: 0, ExprMetadata.simple_selector()), + nw.Expr(lambda: 0, ExprMetadata.selector_single()), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" @@ -230,7 +230,7 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr(lambda: 0, ExprMetadata.simple_selector()).__dir__() + for i in nw.Expr(lambda: 0, ExprMetadata.selector_single()).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -252,7 
+252,7 @@ expr_internal = [ i for i in getattr( - nw.Expr(lambda: 0, ExprMetadata.simple_selector()), + nw.Expr(lambda: 0, ExprMetadata.selector_single()), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_"