diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 392093fc16..198778fb06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -106,7 +106,8 @@ repos: narwhals/.*typing\.py| narwhals/_plan/functions\.py| narwhals/_plan/expressions/ranges\.py| - narwhals/_plan/schema\.py + narwhals/_plan/schema\.py| + narwhals/_plan/expressions/selectors\.py ) - id: pull-request-target name: don't use `pull_request_target` diff --git a/narwhals/_plan/__init__.py b/narwhals/_plan/__init__.py index afeff442c0..c40d064a95 100644 --- a/narwhals/_plan/__init__.py +++ b/narwhals/_plan/__init__.py @@ -1,8 +1,8 @@ from __future__ import annotations +from narwhals._plan import selectors from narwhals._plan.dataframe import DataFrame -from narwhals._plan.expr import Expr, Selector -from narwhals._plan.expressions import selectors +from narwhals._plan.expr import Expr from narwhals._plan.functions import ( all, all_horizontal, @@ -25,6 +25,7 @@ sum_horizontal, when, ) +from narwhals._plan.selectors import Selector from narwhals._plan.series import Series __all__ = [ diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index 20a886c3a3..abfeb6c1b0 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -35,338 +35,318 @@ [6e57eff4f059c748cf84ddcae276a74318720b85]: https://github.com/narwhals-dev/narwhals/commit/6e57eff4f059c748cf84ddcae276a74318720b85 """ -# ruff: noqa: A002 from __future__ import annotations from collections import deque -from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Union -from narwhals._plan import common, meta -from narwhals._plan._guards import is_horizontal_reduction -from narwhals._plan._immutable import Immutable +from narwhals._plan import common, expressions as ir, meta from narwhals._plan.exceptions import ( - column_index_error, + binary_expr_multi_output_error, column_not_found_error, duplicate_error, ) from narwhals._plan.expressions import ( Alias, - All, - Columns, - Exclude, ExprIR, - IndexColumns, KeepName, NamedIR, - Nth, RenameAlias, SelectorIR, - _ColumnSelection, - col, - cols, ) -from narwhals._plan.schema import ( - FrozenColumns, - FrozenSchema, - IntoFrozenSchema, - freeze_schema, -) -from narwhals.dtypes import DType -from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals._plan.schema import FrozenSchema, IntoFrozenSchema, freeze_schema +from narwhals._typing_compat import assert_never +from narwhals._utils import check_column_names_are_unique, zip_strict +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Collection, Iterable, Iterator, Sequence from typing_extensions import TypeAlias - from narwhals._plan.typing import Seq - from narwhals.dtypes import DType - - -Excluded: TypeAlias = "frozenset[str]" -"""Internally use a `set`, then freeze before returning.""" - -GroupByKeys: TypeAlias = "Seq[str]" -"""Represents `group_by` keys. + from narwhals._plan.typing import Ignored, Seq -They need to be excluded from expansion. 
-""" OutputNames: TypeAlias = "Seq[str]" """Fully expanded, validated output column names, for `NamedIR`s.""" -class ExpansionFlags(Immutable): - """`polars` uses a struct, but we may want to use `enum.Flag`.""" - - __slots__ = ( - "has_exclude", - "has_nth", - "has_selector", - "has_wildcard", - "multiple_columns", - ) - multiple_columns: bool - has_nth: bool - has_wildcard: bool - has_selector: bool - has_exclude: bool - - @property - def expands(self) -> bool: - """If we add struct stuff, that would slot in here as well.""" - return self.multiple_columns - - @staticmethod - def from_ir(ir: ExprIR, /) -> ExpansionFlags: - """Subset of [`find_flags`]. - - [`find_flags`]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L607-L660 - """ - multiple_columns: bool = False - has_nth: bool = False - has_wildcard: bool = False - has_selector: bool = False - has_exclude: bool = False - for e in ir.iter_left(): - if isinstance(e, (_ColumnSelection, SelectorIR)): - if isinstance(e, (Columns, IndexColumns)): - multiple_columns = True - elif isinstance(e, Nth): - has_nth = True - elif isinstance(e, All): - has_wildcard = True - elif isinstance(e, SelectorIR): - has_selector = True - elif isinstance(e, Exclude): - has_exclude = True - return ExpansionFlags( - multiple_columns=multiple_columns, - has_nth=has_nth, - has_wildcard=has_wildcard, - has_selector=has_selector, - has_exclude=has_exclude, - ) - - def with_multiple_columns(self) -> ExpansionFlags: - return common.replace(self, multiple_columns=True) +Combination: TypeAlias = Union[ + ir.SortBy, + ir.BinaryExpr, + ir.TernaryExpr, + ir.Filter, + ir.OrderedWindowExpr, + ir.WindowExpr, +] def prepare_projection( - exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema + exprs: Sequence[ExprIR], /, ignored: Ignored = (), *, schema: IntoFrozenSchema ) -> tuple[Seq[NamedIR], FrozenSchema]: - """Expand IRs into named column selections. + """Expand IRs into named column projections. **Primary entry-point**, for `select`, `with_columns`, and any other context that requires resolving expression names. Arguments: - exprs: IRs that *may* contain things like `Columns`, `SelectorIR`, `Exclude`, etc. - keys: Names of `group_by` columns. - schema: Scope to expand multi-column selectors in. + exprs: IRs that *may* contain arbitrarily nested expressions. + ignored: Names of `group_by` columns. + schema: Scope to expand selectors in. 
""" - frozen_schema = freeze_schema(schema) - rewritten = rewrite_projections(tuple(exprs), keys=keys, schema=frozen_schema) - output_names = ensure_valid_exprs(rewritten, frozen_schema) - named_irs = into_named_irs(rewritten, output_names) - return named_irs, frozen_schema - - -def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: - if len(exprs) != len(names): - msg = f"zip length mismatch: {len(exprs)} != {len(names)}" - raise ValueError(msg) - return tuple( - NamedIR(expr=remove_alias(ir), name=name) for ir, name in zip(exprs, names) - ) + expander = Expander(schema, ignored) + return expander.prepare_projection(exprs), expander.schema -def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: - """Raise an appropriate error if we can't materialize.""" - output_names = _ensure_output_names_unique(exprs) - root_names = meta.root_names_unique(exprs) - if not (set(schema.names).issuperset(root_names)): - raise column_not_found_error(root_names, schema) - return output_names +def expand_selector_irs_names( + selectors: Sequence[SelectorIR], /, ignored: Ignored = (), *, schema: IntoFrozenSchema +) -> OutputNames: + """Expand selector-only input into the column names that match. + Similar to `prepare_projection`, but intended for allowing a subset of `Expr` and all `Selector`s + to be used in more places like `DataFrame.{drop,sort,partition_by}`. -def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: - names = tuple(e.meta.output_name() for e in exprs) + Arguments: + selectors: IRs that **only** contain subclasses of `SelectorIR`. + ignored: Names of `group_by` columns. + schema: Scope to expand selectors in. + """ + names = tuple(Expander(schema, ignored).iter_expand_selector_names(selectors)) if len(names) != len(set(names)): - raise duplicate_error(exprs) + # NOTE: Can't easily reuse `duplicate_error`, falling back to main for now + check_column_names_are_unique(names) return names -def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: - def fn(child: ExprIR, /) -> ExprIR: - if is_horizontal_reduction(child): - rewrites = rewrite_projections(child.input, schema=schema) - return common.replace(child, input=rewrites) - return child - - return origin.map_ir(fn) - - -def replace_nth(origin: ExprIR, /, col_names: FrozenColumns) -> ExprIR: - n_fields = len(col_names) - +def remove_alias(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Nth): - if not is_index_in_range(child.index, n_fields): - raise column_index_error(child.index, col_names) - return col(col_names[child.index]) - return child + return child.expr if isinstance(child, (Alias, RenameAlias)) else child return origin.map_ir(fn) -def is_index_in_range(index: int, n_fields: int) -> bool: - idx = index + n_fields if index < 0 else index - return not (idx < 0 or idx >= n_fields) - +def replace_keep_name(origin: ExprIR, /) -> ExprIR: + root_name = meta.root_name_first(origin) -def remove_alias(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: - return child.expr if isinstance(child, Alias) else child + return child.expr.alias(root_name) if isinstance(child, KeepName) else child return origin.map_ir(fn) -def replace_with_column( - origin: ExprIR, tp: type[_ColumnSelection], /, name: str -) -> ExprIR: - """Expand a single column within a multi-selection using `name`.""" +class Expander: + __slots__ = ("ignored", "schema") + schema: FrozenSchema + ignored: Ignored + + def __init__(self, scope: 
IntoFrozenSchema, ignored: Ignored = ()) -> None:
+        self.schema = freeze_schema(scope)
+        self.ignored = ignored
+
+    def iter_expand_exprs(self, exprs: Iterable[ExprIR], /) -> Iterator[ExprIR]:
+        # Iteratively expand each expression in `exprs`
+        for expr in exprs:
+            yield from self._expand(expr)
+
+    def iter_expand_selector_names(
+        self, selectors: Iterable[SelectorIR], /
+    ) -> Iterator[str]:
+        for s in selectors:
+            yield from s.iter_expand_names(self.schema, self.ignored)
+
+    def prepare_projection(self, exprs: Collection[ExprIR], /) -> Seq[NamedIR]:
+        output_names = deque[str]()
+        named_irs = deque[NamedIR]()
+        root_names = set[str]()
+
+        # NOTE: Collecting here isn't ideal (perf-wise), but the expanded `ExprIR`s
+        # carry more useful information for the error messages below.
+        # Another option would be to stay lazy and repeat the expansion work in the
+        # error case; that way the happy path pays no cost, and it doesn't matter
+        # how long we take to display the message when we're already raising.
+        expanded = tuple(self.iter_expand_exprs(exprs))
+        for e in expanded:
+            # NOTE: Empty string is allowed as a name, but is falsy
+            if (name := e.meta.output_name(raise_if_undetermined=False)) is not None:
+                target = e
+            elif meta.has_expr_ir(e, KeepName):
+                replaced = replace_keep_name(e)
+                name = replaced.meta.output_name()
+                target = replaced
+            else:
+                msg = f"Unable to determine output name for expression, got: `{e!r}`"
+                raise NotImplementedError(msg)
+            output_names.append(name)
+            named_irs.append(ir.named_ir(name, remove_alias(target)))
+            root_names.update(meta.iter_root_names(e))
+        if len(output_names) != len(set(output_names)):
+            raise duplicate_error(expanded)
+        if not (set(self.schema).issuperset(root_names)):
+            raise column_not_found_error(root_names, self.schema)
+        return tuple(named_irs)
+
+    def _expand(self, expr: ExprIR, /) -> Iterator[ExprIR]:
+        # For a single expr, fully expand all parts of it
+        if all(not e.needs_expansion() for e in expr.iter_left()):
+            yield expr
+        else:
+            yield from self._expand_recursive(expr)
+
+    def _expand_recursive(self, origin: ExprIR, /) -> Iterator[ExprIR]:
+        # Dispatch the kind of expansion, based on the type of expr
+        # Every other method will call back here
+        # Based on https://github.com/pola-rs/polars/blob/5b90db75911c70010d0c0a6941046e6144af88d4/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs#L253-L850
+        if isinstance(origin, _EXPAND_NONE):
+            yield origin
+        elif isinstance(origin, ir.SelectorIR):
+            names = origin.iter_expand_names(self.schema, self.ignored)
+            yield from (ir.Column(name=name) for name in names)
+        elif isinstance(origin, _EXPAND_SINGLE):
+            for expr in self._expand_recursive(origin.expr):
+                yield origin.__replace__(expr=expr)
+        elif isinstance(origin, _EXPAND_COMBINATION):
+            yield from self._expand_combination(origin)
+        elif isinstance(origin, ir.FunctionExpr):
+            yield from self._expand_function_expr(origin)
+        else:
+            msg = f"Didn't expect to see {type(origin).__name__}"
+            raise NotImplementedError(msg)
+
+    def _expand_inner(self, children: Seq[ExprIR], /) -> Iterator[ExprIR]:
+        """Use when we want to expand non-root nodes, *without* duplicating the root.
+
+        If we wrote:
+
+            col("a").over(col("c", "d", "e"))
+
+        Then the expanded version should be:
+
+            col("a").over(col("c"), col("d"), col("e"))
+
+        An **incorrect** expansion would duplicate the root, which raises a duplicate name error unless aliased:
+
+            col("a").over(col("c"))
+            col("a").over(col("d"))
+            col("a").over(col("e"))
+
+        It would also error if we needed to expand both sides:
+
+            col("a", "b").over(col("c", "d", "e"))
+
+        Since that would become:
+
+            col("a").over(col("c"))
+            col("b").over(col("d"))
+            col().over(col("e"))  # InvalidOperationError: cannot combine selectors that produce a different number of columns (3 != 2)
+        """
+        # used by
+        # - `_expand_combination` (tuple fields)
+        # - `_expand_function_expr` (horizontal)
+        for child in children:
+            yield from self._expand_recursive(child)
+
+    def _expand_only(self, child: ExprIR, /) -> ExprIR:
+        # used by
+        # - `_expand_combination` (ExprIR fields)
+        # - `_expand_function_expr` (all others that have len(inputs)>=2, call on non-root)
+        iterable = self._expand_recursive(child)
+        first = next(iterable)
+        if (second := next(iterable, None)) is not None:
+            msg = f"Multi-output expressions are not supported in this context, got: `{second!r}`"  # pragma: no cover
+            raise MultiOutputExpressionError(msg)  # pragma: no cover
+        return first
+
+    # TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves
+    def _expand_combination(self, origin: Combination, /) -> Iterator[Combination]:
+        changes: dict[str, Any] = {}
+        if isinstance(origin, (ir.WindowExpr, ir.Filter, ir.SortBy)):
+            if isinstance(origin, ir.WindowExpr):
+                if partition_by := origin.partition_by:
+                    changes["partition_by"] = tuple(self._expand_inner(partition_by))
+                if isinstance(origin, ir.OrderedWindowExpr):
+                    changes["order_by"] = tuple(self._expand_inner(origin.order_by))
+            elif isinstance(origin, ir.SortBy):
+                changes["by"] = tuple(self._expand_inner(origin.by))
+            else:
+                changes["by"] = self._expand_only(origin.by)
+            replaced = common.replace(origin, **changes)
+            for root in self._expand_recursive(replaced.expr):
+                yield common.replace(replaced, expr=root)
+        elif isinstance(origin, ir.BinaryExpr):
+            yield from self._expand_binary_expr(origin)
+        elif isinstance(origin, ir.TernaryExpr):
+            changes["truthy"] = self._expand_only(origin.truthy)
+            changes["predicate"] = self._expand_only(origin.predicate)
+            changes["falsy"] = self._expand_only(origin.falsy)
+            yield origin.__replace__(**changes)
+        else:
+            assert_never(origin)
+
+    def _expand_binary_expr(self, origin: ir.BinaryExpr, /) -> Iterator[ir.BinaryExpr]:
+        it_lefts = self._expand_recursive(origin.left)
+        it_rights = self._expand_recursive(origin.right)
+        # NOTE: Fast-path that doesn't require collection
+        # - Will miss selectors that expand to 1 column
+        if not origin.meta.has_multiple_outputs():
+            for left, right in zip_strict(it_lefts, it_rights):
+                yield origin.__replace__(left=left, right=right)
+            return
+        # NOTE: Covers 1:1 (where either is a selector), N:N
+        lefts, rights = tuple(it_lefts), tuple(it_rights)
+        len_left, len_right = len(lefts), len(rights)
+        if len_left == len_right:
+            for left, right in zip_strict(lefts, rights):
+                yield origin.__replace__(left=left, right=right)
+        # NOTE: 1:M
+        elif len_left == 1:
+            binary = origin.__replace__(left=lefts[0])
+            yield from (binary.__replace__(right=right) for right in rights)
+        # NOTE: M:1
+        elif len_right == 1:
+            binary = origin.__replace__(right=rights[0])
+            yield from (binary.__replace__(left=left) for left in lefts)
+        else:
+            raise binary_expr_multi_output_error(origin, lefts, rights)
+
+    def _expand_function_expr(
+        self, origin: ir.FunctionExpr, /
+    ) -> Iterator[ir.FunctionExpr]:
+        if origin.options.is_input_wildcard_expansion():
+            reduced = tuple(self._expand_inner(origin.input))
+            yield origin.__replace__(input=reduced)
+        else:
+            if non_root := origin.input[1:]:
+                children = tuple(self._expand_only(child) for child in non_root)
+            else:
+                children = ()
+            for root in self._expand_recursive(origin.input[0]):
+                yield origin.__replace__(input=(root, *children))

-    def fn(child: ExprIR, /) -> ExprIR:
-        if isinstance(child, tp):
-            return col(name)
-        return child.expr if isinstance(child, Exclude) else child
-
-    return origin.map_ir(fn)
+_EXPAND_NONE = (ir.Column, ir.Literal, ir.Len)
+"""We're at the root; nothing left to expand."""

+_EXPAND_SINGLE = (ir.Alias, ir.Cast, ir.AggExpr, ir.Sort, ir.KeepName, ir.RenameAlias)
+"""One (direct) child, always stored in `self.expr`.

+An expansion will always just be cloning *everything but* `self.expr`,
+so we only need to be concerned with a **single** attribute.

-def replace_selector(ir: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
-    def fn(child: ExprIR, /) -> ExprIR:
-        return expand_selector(child, schema) if isinstance(child, SelectorIR) else child
+Say we had:

-    return ir.map_ir(fn)
+    origin = Cast(expr=ByName(names=("one", "two"), require_all=True), dtype=String)

+This would expand to:

-@lru_cache(maxsize=100)
-def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) -> bool:
-    """Cached version of `SelectorIR.matches.column`.
+    cast_one = Cast(expr=Column(name="one"), dtype=String)
+    cast_two = Cast(expr=Column(name="two"), dtype=String)

-    Allows results of evaluations can be shared across:
-    - Instances of `SelectorIR`
-    - Multiple schemas
-    """
-    return selector.matches_column(name, dtype)
-
-
-@lru_cache(maxsize=100)
-def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns:
-    """Expand `selector` into `Columns`, within the context of `schema`."""
-    matches = selector_matches_column
-    return cols(*(k for k, v in schema.items() if matches(selector, k, v)))
-
-
-def rewrite_projections(
-    input: Seq[ExprIR],  # `FunctionExpr.input`
-    /,
-    keys: GroupByKeys = (),
-    *,
-    schema: FrozenSchema,
-) -> Seq[ExprIR]:
-    result: deque[ExprIR] = deque()
-    for expr in input:
-        expanded = expand_function_inputs(expr, schema=schema)
-        flags = ExpansionFlags.from_ir(expanded)
-        if flags.has_selector:
-            expanded = replace_selector(expanded, schema=schema)
-            flags = flags.with_multiple_columns()
-        result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags))
-    return tuple(result)
-
-
-def iter_replace(
-    origin: ExprIR,
-    /,
-    keys: GroupByKeys,
-    *,
-    col_names: FrozenColumns,
-    flags: ExpansionFlags,
-) -> Iterator[ExprIR]:
-    if flags.has_nth:
-        origin = replace_nth(origin, col_names)
-    if flags.expands:
-        it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns)))
-        if e := next(it, None):
-            if isinstance(e, Columns):
-                if not _all_columns_match(origin, e):
-                    msg = "expanding more than one `col` is not allowed"
-                    raise ComputeError(msg)
-                names: Iterable[str] = e.names
-            else:
-                names = _iter_index_names(e, col_names)
-            exclude = prepare_excluded(origin, keys, flags)
-            yield from expand_column_selection(origin, type(e), names, exclude)
-        elif flags.has_wildcard:
-            exclude = prepare_excluded(origin, keys, flags)
-            yield from expand_column_selection(origin, All, col_names, exclude)
-        else:
-            yield rewrite_special_aliases(origin)
-
-
-def prepare_excluded(
-    
origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, /
-) -> Excluded:
-    """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555."""
-    gb_keys = frozenset(keys)
-    if not flags.has_exclude:
-        return gb_keys
-    return gb_keys.union(*(e.names for e in origin.iter_left() if isinstance(e, Exclude)))
-
-
-def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool:
-    it = (e == columns if isinstance(e, Columns) else True for e in origin.iter_left())
-    return all(it)
-
-
-def _iter_index_names(indices: IndexColumns, names: FrozenColumns, /) -> Iterator[str]:
-    n_fields = len(names)
-    for index in indices.indices:
-        if not is_index_in_range(index, n_fields):
-            raise column_index_error(index, names)
-        yield names[index]
-
-
-def expand_column_selection(
-    origin: ExprIR, tp: type[_ColumnSelection], /, names: Iterable[str], exclude: Excluded
-) -> Iterator[ExprIR]:
-    for name in names:
-        if name not in exclude:
-            yield rewrite_special_aliases(replace_with_column(origin, tp, name))
-
-
-def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR:
-    """Expand `KeepName` and `RenameAlias` into `Alias`.
-
-    Warning:
-        Only valid **after**
-        - Expanding all selections into `Column`
-        - Dealing with `FunctionExpr.input`
-    """
-    if meta.has_expr_ir(origin, KeepName, RenameAlias):
-        if isinstance(origin, KeepName):
-            parent = origin.expr
-            return parent.alias(next(iter(parent.meta.root_names())))
-        if isinstance(origin, RenameAlias):
-            parent = origin.expr
-            leaf_name_or_err = meta.get_single_leaf_name(parent)
-            if not isinstance(leaf_name_or_err, str):
-                raise leaf_name_or_err
-            return parent.alias(origin.function(leaf_name_or_err))
-        msg = "`keep`, `suffix`, `prefix` should be last expression"
-        raise InvalidOperationError(msg)
-    return origin
+"""
+_EXPAND_COMBINATION = (
+    ir.SortBy,
+    ir.BinaryExpr,
+    ir.TernaryExpr,
+    ir.Filter,
+    ir.OrderedWindowExpr,
+    ir.WindowExpr,
+)
+"""More than one (direct) child, and those can be nested."""
diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py
index 8f66f2d9fa..80b6a0a8d4 100644
--- a/narwhals/_plan/_expr_ir.py
+++ b/narwhals/_plan/_expr_ir.py
@@ -7,7 +7,8 @@
 from narwhals._plan._immutable import Immutable
 from narwhals._plan.common import replace
 from narwhals._plan.options import ExprIROptions
-from narwhals._plan.typing import ExprIRT
+from narwhals._plan.typing import ExprIRT, Ignored
+from narwhals.exceptions import InvalidOperationError
 from narwhals.utils import Version
 
 if TYPE_CHECKING:
@@ -17,11 +18,14 @@
     from typing_extensions import Self
 
     from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
-    from narwhals._plan.expr import Expr, Selector
+    from narwhals._plan.expr import Expr
     from narwhals._plan.expressions.expr import Alias, Cast, Column
     from narwhals._plan.meta import MetaNamespace
+    from narwhals._plan.schema import FrozenSchema
+    from narwhals._plan.selectors import Selector
     from narwhals._plan.typing import ExprIRT2, MapIR, Seq
     from narwhals.dtypes import DType
+    from narwhals.typing import IntoDType
 
 
 class ExprIR(Immutable):
@@ -59,10 +63,17 @@ def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
         tp = expr.Expr if version is Version.MAIN else expr.ExprV1
         return tp._from_ir(self)
 
+    def to_selector_ir(self) -> SelectorIR:
+        msg = f"cannot turn `{self!r}` into a selector"
+        raise InvalidOperationError(msg)
+
     @property
     def is_scalar(self) -> bool:
         return False
 
+    def 
needs_expansion(self) -> bool:
+        return any(isinstance(e, SelectorIR) for e in self.iter_left())
+
     def map_ir(self, function: MapIR, /) -> ExprIR:
         """Apply `function` to each child node, returning a new `ExprIR`.
 
@@ -141,7 +152,7 @@ def iter_right(self) -> Iterator[ExprIR]:
             if isinstance(child, ExprIR):
                 yield from child.iter_right()
             else:
-                for node in reversed(child):
+                for node in reversed(child):  # pragma: no cover
                     yield from node.iter_right()
 
     def iter_root_names(self) -> Iterator[ExprIR]:
@@ -186,20 +197,49 @@ def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprI
 
 class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()):
     def to_narwhals(self, version: Version = Version.MAIN) -> Selector:
-        from narwhals._plan import expr
+        from narwhals._plan.selectors import Selector, SelectorV1
+
+        tp = Selector if version is Version.MAIN else SelectorV1
+        return tp._from_ir(self)
+
+    # NOTE: Corresponds with `Selector.iter_expand`
+    # A longer name is used here to distinguish expression and name-only expansion
+    def iter_expand_names(
+        self, schema: FrozenSchema, ignored_columns: Ignored
+    ) -> Iterator[str]:
+        """Yield column names that match the selector, in `schema` order[^1].
+
+        Adapted from [upstream].
+
+        Arguments:
+            schema: Target scope to expand the selector in.
+            ignored_columns: Names of `group_by` columns, which are excluded[^2] from the result.
 
-        if version is Version.MAIN:
-            return expr.Selector._from_ir(self)
-        return expr.SelectorV1._from_ir(self)
+        Note:
+            [^1]: `ByName`, `ByIndex` return their inputs in the given order, not in schema order.
 
-    def matches_column(self, name: str, dtype: DType) -> bool:
-        """Return True if we can select this column.
+        Note:
+            [^2]: `ByName`, `ByIndex` will never be ignored.
 
-        - Thinking that we could get more cache hits on an individual column basis.
-        - May also be more efficient to not iterate over the schema for every selector
-          - Instead do one pass, evaluating every selector against a single column at a time
+        [upstream]: https://github.com/pola-rs/polars/blob/2b241543851800595efd343be016b65cdbdd3c9f/crates/polars-plan/src/dsl/selector.rs#L188-L198
         """
-        raise NotImplementedError(type(self))
+        msg = f"{type(self).__name__}.iter_expand_names"
+        raise NotImplementedError(msg)
+
+    def matches(self, dtype: IntoDType) -> bool:
+        """Return True if we can select this dtype."""
+        msg = f"{type(self).__name__}.matches"
+        raise NotImplementedError(msg)
+
+    def to_dtype_selector(self) -> Self:
+        msg = f"{type(self).__name__}.to_dtype_selector"
+        raise NotImplementedError(msg)
+
+    def to_selector_ir(self) -> Self:
+        return self
+
+    def needs_expansion(self) -> bool:
+        return True
 
 
 class NamedIR(Immutable, Generic[ExprIRT]):
@@ -244,10 +284,10 @@ def map_ir(self, function: MapIR, /) -> Self:
     def __repr__(self) -> str:
         return f"{self.name}={self.expr!r}"
 
-    def _repr_html_(self) -> str:
+    def _repr_html_(self) -> str:  # pragma: no cover
         return f"{self.name}={self.expr._repr_html_()}"
 
-    def is_elementwise_top_level(self) -> bool:
+    def is_elementwise_top_level(self) -> bool:  # pragma: no cover
         """Return True if the outermost node is elementwise.
 
Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index 780070f038..da6b84880d 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -3,10 +3,12 @@ from __future__ import annotations import datetime as dt +import re # `_utils` imports at module-level from decimal import Decimal from typing import TYPE_CHECKING, Any, TypeVar from narwhals._utils import _hasattr_static +from narwhals.dtypes import DType if TYPE_CHECKING: from typing_extensions import TypeIs @@ -14,8 +16,14 @@ from narwhals._plan import expressions as ir from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.expr import Expr + from narwhals._plan.selectors import Selector from narwhals._plan.series import Series - from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq + from narwhals._plan.typing import ( + ColumnNameOrSelector, + IntoExprColumn, + NativeSeriesT, + Seq, + ) from narwhals.typing import NonNestedLiteral T = TypeVar("T") @@ -44,6 +52,12 @@ def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 return expr +def _selectors(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import selectors + + return selectors + + def _series(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 from narwhals._plan import series @@ -58,6 +72,10 @@ def is_expr(obj: Any) -> TypeIs[Expr]: return isinstance(obj, _expr().Expr) +def is_selector(obj: Any) -> TypeIs[Selector]: + return isinstance(obj, _selectors().Selector) + + def is_column(obj: Any) -> TypeIs[Expr]: """Indicate if the given object is a basic/unaliased column.""" return is_expr(obj) and obj.meta.is_column() @@ -71,6 +89,13 @@ def is_into_expr_column(obj: Any) -> TypeIs[IntoExprColumn]: return isinstance(obj, (str, _expr().Expr, _series().Series)) +def is_column_name_or_selector( + obj: Any, *, allow_expr: bool = False +) -> TypeIs[ColumnNameOrSelector]: + tps = (str, _selectors().Selector) if not allow_expr else (str, _expr().Expr) + return isinstance(obj, tps) + + def is_compliant_series( obj: CompliantSeries[NativeSeriesT] | Any, ) -> TypeIs[CompliantSeries[NativeSeriesT]]: @@ -78,7 +103,9 @@ def is_compliant_series( def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: - return isinstance(obj, (str, bytes, _series().Series)) or is_compliant_series(obj) + return isinstance(obj, (str, bytes, _series().Series, DType)) or is_compliant_series( + obj + ) def is_window_expr(obj: Any) -> TypeIs[ir.WindowExpr]: @@ -106,9 +133,11 @@ def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]: return isinstance(obj, _ir().Literal) -def is_horizontal_reduction(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: - return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() +# TODO @dangotbanned: Coverage +# Used in `ArrowNamespace._vertical`, but only horizontal is covered +def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: # pragma: no cover + return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) -def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: - return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) +def is_re_pattern(obj: Any) -> TypeIs[re.Pattern[str]]: + return isinstance(obj, re.Pattern) diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py index 09cc7fc90a..b64090f48b 100644 --- a/narwhals/_plan/_immutable.py +++ b/narwhals/_plan/_immutable.py @@ -31,6 +31,8 @@ class 
Immutable(metaclass=ImmutableMeta): # NOTE: Trying to avoid this being added to synthesized `__init__` # Seems to be the only difference when decorating the metaclass __immutable_hash_value__: int + else: # pragma: no cover + ... __immutable_keys__: ClassVar[tuple[str, ...]] @@ -108,7 +110,9 @@ def __init__(self, **kwds: Any) -> None: def _field_str(name: str, value: Any) -> str: if isinstance(value, tuple): - inner = ", ".join(f"{v}" for v in value) + inner = ", ".join( + (f"{v!s}" if not isinstance(v, str) else f"{v!r}") for v in value + ) return f"{name}=[{inner}]" if isinstance(value, str): return f"{name}={value!r}" diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index c2a5cc7c2f..5b23cefde4 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -1,29 +1,39 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +import operator +from collections import deque +from collections.abc import Collection, Iterable, Sequence # ruff: noqa: A002 +from functools import reduce from itertools import chain from typing import TYPE_CHECKING -from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject -from narwhals._plan.exceptions import ( - invalid_into_expr_error, - is_iterable_pandas_error, - is_iterable_polars_error, +from narwhals._native import is_native_pandas +from narwhals._plan._guards import ( + is_column_name_or_selector, + is_expr, + is_into_expr_column, + is_iterable_reject, + is_selector, ) -from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series +from narwhals._plan.common import flatten_hash_safe +from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error +from narwhals._utils import qualified_type_name +from narwhals.dependencies import get_polars from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Iterator from typing import Any, TypeVar - import polars as pl from typing_extensions import TypeAlias, TypeIs - from narwhals._plan.expressions import ExprIR + from narwhals._plan.expr import Expr + from narwhals._plan.expressions import ExprIR, SelectorIR + from narwhals._plan.selectors import Selector from narwhals._plan.typing import ( + ColumnNameOrSelector, IntoExpr, IntoExprColumn, OneOrIterable, @@ -117,13 +127,64 @@ def parse_into_expr_ir( expr = col(input) elif isinstance(input, list): if list_as_series is None: - raise TypeError(input) + raise TypeError(input) # pragma: no cover expr = lit(list_as_series(input)) else: expr = lit(input, dtype=dtype) return expr._ir +def parse_into_selector_ir( + input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True +) -> SelectorIR: + return _parse_into_selector(input, require_all=require_all)._ir + + +def _parse_into_selector( + input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True +) -> Selector: + if is_selector(input): + selector = input + elif isinstance(input, str): + import narwhals._plan.selectors as cs + + selector = cs.by_name(input, require_all=require_all) + elif is_expr(input): + selector = input.meta.as_selector() + else: + msg = f"cannot turn {qualified_type_name(input)!r} into a selector" + raise TypeError(msg) + return selector + + +def parse_into_combined_selector_ir( + *inputs: OneOrIterable[ColumnNameOrSelector], require_all: bool = True +) -> SelectorIR: + import narwhals._plan.selectors as cs + + flat = tuple(flatten_hash_safe(inputs)) + selectors = deque["Selector"]() + if names := tuple(el for el in flat if 
isinstance(el, str)):
+        selector = cs.by_name(names, require_all=require_all)
+        if len(names) == len(flat):
+            return selector._ir
+        selectors.append(selector)
+    selectors.extend(_parse_into_selector(el) for el in flat if not isinstance(el, str))
+    return _any_of(selectors)._ir
+
+
+def _any_of(selectors: Collection[Selector], /) -> Selector:
+    import narwhals._plan.selectors as cs
+
+    if not selectors:
+        s: Selector = cs.empty()
+    elif len(selectors) == 1:
+        s = next(iter(selectors))
+    else:
+        s = reduce(operator.or_, selectors)
+    return s
+
+
 def parse_into_seq_of_expr_ir(
     first_input: OneOrIterable[IntoExpr] = (),
     *more_inputs: IntoExpr | _RaisesInvalidIntoExprError,
@@ -170,11 +231,39 @@ def _parse_sort_by_into_iter_expr_ir(
 ) -> Iterator[ExprIR]:
     for e in _parse_into_iter_expr_ir(by, *more_by):
         if e.is_scalar:
-            msg = f"All expressions sort keys must preserve length, but got:\n{e!r}"
-            raise InvalidOperationError(msg)
+            msg = f"All sort key expressions must preserve length, but got:\n{e!r}"  # pragma: no cover
+            raise InvalidOperationError(msg)  # pragma: no cover
         yield e
 
 
+def parse_into_seq_of_selector_ir(
+    first_input: OneOrIterable[ColumnNameOrSelector], *more_inputs: ColumnNameOrSelector
+) -> Seq[SelectorIR]:
+    return tuple(_parse_into_iter_selector_ir(first_input, more_inputs))
+
+
+def _parse_into_iter_selector_ir(
+    first_input: OneOrIterable[ColumnNameOrSelector],
+    more_inputs: tuple[ColumnNameOrSelector, ...],
+    /,
+) -> Iterator[SelectorIR]:
+    if is_column_name_or_selector(first_input) and not more_inputs:
+        yield parse_into_selector_ir(first_input)
+        return
+
+    if not _is_empty_sequence(first_input):
+        if _is_iterable(first_input) and not isinstance(first_input, str):
+            if more_inputs:  # pragma: no cover
+                raise invalid_into_expr_error(first_input, more_inputs, {})
+            else:
+                for into in first_input:  # type: ignore[var-annotated]
+                    yield parse_into_selector_ir(into)
+        else:
+            yield parse_into_selector_ir(first_input)
+    for into in more_inputs:  # pragma: no cover
+        yield parse_into_selector_ir(into)
+
+
 def _parse_into_iter_expr_ir(
     first_input: OneOrIterable[IntoExpr],
     *more_inputs: IntoExpr | list[Any],
@@ -236,23 +325,23 @@ def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR:
         msg = "at least one predicate or constraint must be provided"
         raise TypeError(msg)
     if second := next(predicates, None):
-        return AllHorizontal().to_function_expr(first, second, *predicates)
-    return first
+        inputs = first, second, *predicates
+    elif first.meta.has_multiple_outputs():
+        # NOTE: Safeguarding against https://github.com/pola-rs/polars/issues/25022
+        inputs = (first,)
+    else:
+        return first
+    return AllHorizontal().to_function_expr(*inputs)
 
 
 def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]:
-    if is_pandas_dataframe(obj) or is_pandas_series(obj):
-        raise is_iterable_pandas_error(obj)
-    if _is_polars(obj):
-        raise is_iterable_polars_error(obj)
+    if is_native_pandas(obj) or (
+        (pl := get_polars())
+        and isinstance(obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame))
+    ):
+        raise is_iterable_error(obj)
    return isinstance(obj, Iterable)
 
 
 def _is_empty_sequence(obj: Any) -> bool:
     return isinstance(obj, Sequence) and not obj
-
-
-def _is_polars(obj: Any) -> TypeIs[pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame]:
-    return (pl := get_polars()) is not None and isinstance(
-        obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame)
-    )
diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py
index fd26364e66..1a4d40fb7d 100644
--- 
a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -88,7 +88,7 @@ def map_ir( origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR ) -> NamedOrExprIRT: """Apply one or more functions, sequentially, to all of `origin`'s children.""" - if more_functions: + if more_functions: # pragma: no cover result = origin for fn in (function, *more_functions): result = result.map_ir(fn) diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index f99fad5289..990a688d98 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -158,14 +158,6 @@ def _aggregate(aggs: Iterable[AggSpec], /, keys: Iterable[Field] | None = None) return Decl("aggregate", pac.AggregateNodeOptions(aggs_, keys=keys_)) -def aggregate(aggs: Iterable[AggSpec], /) -> Decl: - """May only use [Scalar aggregate] functions. - - [Scalar aggregate]: https://arrow.apache.org/docs/cpp/compute.html#aggregations - """ - return _aggregate(aggs) - - def group_by(keys: Iterable[Field], aggs: Iterable[AggSpec], /) -> Decl: """May only use [Hash aggregate] functions, requires grouping. diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 15b9dc80c0..d7ea2f6639 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -11,13 +11,12 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import acero, functions as fn from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar -from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy +from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.compliant.dataframe import EagerDataFrame from narwhals._plan.compliant.typing import namespace from narwhals._plan.expressions import NamedIR -from narwhals._plan.typing import Seq -from narwhals._utils import Implementation, Version, parse_columns_to_drop +from narwhals._utils import Implementation, Version from narwhals.schema import Schema if TYPE_CHECKING: @@ -29,7 +28,7 @@ from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.expressions import ExprIR, NamedIR from narwhals._plan.options import SortMultipleOptions - from narwhals._plan.typing import NonCrossJoinStrategy, Seq + from narwhals._plan.typing import NonCrossJoinStrategy from narwhals.dtypes import DType from narwhals.typing import IntoSchema @@ -92,10 +91,10 @@ def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series] from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) - def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: - df_by = self.select(by) - indices = pc.sort_indices(df_by.native, options=options.to_arrow(df_by.columns)) - return self._with_native(self.native.take(indices)) + def sort(self, by: Sequence[str], options: SortMultipleOptions) -> Self: + native = self.native + indices = pc.sort_indices(native.select(list(by)), options=options.to_arrow(by)) + return self._with_native(native.take(indices)) def with_row_index(self, name: str) -> Self: return self._with_native(self.native.add_column(0, name, fn.int_range(len(self)))) @@ -104,9 +103,8 @@ def get_column(self, name: str) -> Series: chunked = self.native.column(name) return Series.from_native(chunked, name, version=self.version) - def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: - to_drop = 
parse_columns_to_drop(self, columns, strict=strict) - return self._with_native(self.native.drop(to_drop)) + def drop(self, columns: Sequence[str]) -> Self: + return self._with_native(self.native.drop(list(columns))) def drop_nulls(self, subset: Sequence[str] | None) -> Self: if subset is None: @@ -152,7 +150,7 @@ def join( *, how: NonCrossJoinStrategy, left_on: Sequence[str], - right_on: Sequence[str], + right_on: Sequence[str] = (), suffix: str = "_right", ) -> Self: left, right = self.native, other.native @@ -171,3 +169,8 @@ def filter(self, predicate: NamedIR) -> Self: else: mask = acero.lit(resolved.native) return self._with_native(self.native.filter(mask)) + + def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]: + from_native = self._with_native + partitions = partition_by(self.native, by, include_key=include_key) + return [from_native(df) for df in partitions] diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index fb2bad1479..b34cf963ca 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -15,7 +15,7 @@ from narwhals._plan.compliant.expr import EagerExpr from narwhals._plan.compliant.scalar import EagerScalar from narwhals._plan.compliant.typing import namespace -from narwhals._plan.expressions import NamedIR +from narwhals._plan.expressions.functions import NullCount from narwhals._utils import Implementation, Version, _StoresNative, not_implemented from narwhals.exceptions import InvalidOperationError, ShapeError @@ -55,7 +55,15 @@ Not, ) from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr - from narwhals._plan.expressions.functions import Abs, FillNull, Pow + from narwhals._plan.expressions.functions import ( + Abs, + CumAgg, + Diff, + FillNull, + NullCount, + Pow, + Shift, + ) from narwhals.typing import Into1DArray, IntoDType, PythonLiteral Expr: TypeAlias = "ArrowExpr" @@ -322,13 +330,31 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar: return self._with_native(result, name) # TODO @dangotbanned: top-level, complex-ish nodes - # - [ ] `over`/`_ordered` (with partitions) requires `group_by`, `join` - # - [x] `over_ordered` alone should be possible w/ the current API - # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main + # - [ ] Over + # - [x] `over_ordered` + # - [x] `group_by`, `join` + # - [x] `over` (with partitions) + # - [ ] `over_ordered` (with partitions) + # - [ ] `map_batches` + # - [x] elementwise + # - [ ] scalar # - [ ] `rolling_expr` has 4 variants def over(self, node: ir.WindowExpr, frame: Frame, name: str) -> Self: - raise NotImplementedError + resolved = ( + frame._grouper.by_irs(*node.partition_by) + # TODO @dangotbanned: Clean this up so the re-alias isn't needed + .agg_irs(node.expr.alias(name)) + .resolve(frame) + ) + by_names = resolved.key_names + result = ( + frame.select_names(*by_names) + .join(resolved.evaluate(frame), how="left", left_on=by_names) + .get_column(name) + .native + ) + return self._with_native(result, name) def over_ordered( self, node: ir.OrderedWindowExpr, frame: Frame, name: str @@ -338,8 +364,8 @@ def over_ordered( raise NotImplementedError(msg) # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort` - sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by) - options = node.sort_options.to_multiple(len(node.order_by)) + sort_by = tuple(node.order_by_names()) + options = node.sort_options.to_multiple(len(sort_by)) idx_name = temp.column_name(frame) 
sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) @@ -372,6 +398,24 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self: def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError + def shift(self, node: ir.FunctionExpr[Shift], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + return self._with_native(fn.shift(series.native, node.function.n), name) + + def diff(self, node: ir.FunctionExpr[Diff], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + return self._with_native(fn.diff(series.native), name) + + def _cumulative(self, node: ir.FunctionExpr[CumAgg], frame: Frame, name: str) -> Self: + series = self._dispatch_expr(node.input[0], frame, name) + return self._with_native(fn.cumulative(series.native, node.function), name) + + cum_count = _cumulative + cum_min = _cumulative + cum_max = _cumulative + cum_prod = _cumulative + cum_sum = _cumulative + def _is_first_last_distinct( self, node: FunctionExpr[IsFirstDistinct | IsLastDistinct], @@ -393,6 +437,12 @@ def _is_first_last_distinct( is_first_distinct = _is_first_last_distinct is_last_distinct = _is_first_last_distinct + def null_count( + self, node: ir.FunctionExpr[NullCount], frame: Frame, name: str + ) -> Scalar: + series = self._dispatch_expr(node.input[0], frame, name) + return self._with_native(fn.lit(series.native.null_count), name) + class ArrowScalar( _ArrowDispatch["ArrowScalar"], @@ -475,8 +525,21 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar: native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) + def null_count( + self, node: ir.FunctionExpr[NullCount], frame: Frame, name: str + ) -> Self: + native = node.input[0].dispatch(self, frame, name).native + return self._with_native(pa.scalar(0 if native.is_valid else 1), name) + filter = not_implemented() over = not_implemented() over_ordered = not_implemented() map_batches = not_implemented() + # length_preserving rolling_expr = not_implemented() + diff = not_implemented() + cum_sum = not_implemented() # TODO @dangotbanned: is this just self? 
+    cum_count = not_implemented()
+    cum_min = not_implemented()
+    cum_max = not_implemented()
+    cum_prod = not_implemented()
diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py
index 53b56d19b8..5df93c5a84 100644
--- a/narwhals/_plan/arrow/functions.py
+++ b/narwhals/_plan/arrow/functions.py
@@ -4,7 +4,7 @@
 
 import typing as t
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, overload
 
 import pyarrow as pa  # ignore-banned-import
 import pyarrow.compute as pc  # ignore-banned-import
@@ -16,7 +16,7 @@
 )
 from narwhals._plan import expressions as ir
 from narwhals._plan.arrow import options
-from narwhals._plan.expressions import operators as ops
+from narwhals._plan.expressions import functions as F, operators as ops
 from narwhals._utils import Implementation
 
 if TYPE_CHECKING:
@@ -37,6 +37,7 @@
         ChunkedArray,
         ChunkedArrayAny,
         ChunkedOrArrayAny,
+        ChunkedOrArrayT,
         ChunkedOrScalar,
         ChunkedOrScalarAny,
         DataType,
@@ -53,7 +54,7 @@
         StringType,
         UnaryFunction,
     )
-    from narwhals.typing import ClosedInterval, IntoArrowSchema
+    from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral
 
 
 BACKEND_VERSION = Implementation.PYARROW._backend_version()
@@ -91,19 +92,24 @@ def modulus(lhs: Any, rhs: Any) -> Any:
     return sub(lhs, multiply(floor_div, rhs))
 
 
+# TODO @dangotbanned: Somehow fix the typing on this
+# - `_ArrowDispatch` is relying on gradual typing
 _DISPATCH_BINARY: Mapping[type[ops.Operator], BinOp] = {
+    # BinaryComp
     ops.Eq: eq,
     ops.NotEq: not_eq,
     ops.Lt: lt,
     ops.LtEq: lt_eq,
     ops.Gt: gt,
     ops.GtEq: gt_eq,
-    ops.Add: add,
-    ops.Sub: sub,
-    ops.Multiply: multiply,
-    ops.TrueDivide: truediv,
-    ops.FloorDivide: floordiv,
-    ops.Modulus: modulus,
+    # BinaryFunction (well it should be)
+    ops.Add: add,  # BinaryNumericTemporal
+    ops.Sub: sub,  # pyarrow-stubs
+    ops.Multiply: multiply,  # pyarrow-stubs
+    ops.TrueDivide: truediv,  # [[Any, Any], Any]
+    ops.FloorDivide: floordiv,  # [[ArrayOrScalar, ArrayOrScalar], Any]
+    ops.Modulus: modulus,  # [[Any, Any], Any]
+    # BinaryLogical
     ops.And: and_,
     ops.Or: or_,
     ops.ExclusiveOr: xor,
@@ -208,6 +214,70 @@ def n_unique(native: Any) -> pa.Int64Scalar:
     return count(native, mode="all")
 
 
+def _reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    """Unlike other slicing ops, `[::-1]` creates a full copy.
+
+    https://github.com/apache/arrow/issues/19103#issuecomment-1377671886
+    """
+    return native[::-1]
+
+
+def cumulative(native: ChunkedArrayAny, cum_agg: F.CumAgg, /) -> ChunkedArrayAny:
+    func = _CUMULATIVE[type(cum_agg)]
+    if not cum_agg.reverse:
+        return func(native)
+    return _reverse(func(_reverse(native)))
+
+
+def cum_sum(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_sum(native, skip_nulls=True)
+
+
+def cum_min(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_min(native, skip_nulls=True)
+
+
+def cum_max(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_max(native, skip_nulls=True)
+
+
+def cum_prod(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_prod(native, skip_nulls=True)
+
+
+def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny:
+    return cum_sum(is_not_null(native).cast(pa.uint32()))
+
+
+_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = {
+    F.CumSum: cum_sum,
+    F.CumCount: cum_count,
+    F.CumMin: cum_min,
+    F.CumMax: cum_max,
+    F.CumProd: cum_prod,
+}
+
+
+def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined
+    return (
+        pc.pairwise_diff(native)
+        if isinstance(native, pa.Array)
+        else chunked_array(pc.pairwise_diff(native.combine_chunks()))
+    )
+
+
+def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny:
+    if n == 0:
+        return native
+    arr = native
+    if n > 0:
+        arrays = [nulls_like(n, arr), *arr.slice(length=arr.length() - n).chunks]
+    else:
+        arrays = [*arr.slice(offset=-n).chunks, nulls_like(-n, arr)]
+    return pa.chunked_array(arrays)
+
+
 def is_between(
     native: ChunkedOrScalar[ScalarT],
     lower: ChunkedOrScalar[ScalarT],
@@ -271,24 +341,45 @@ def int_range(
     return pa.chunked_array([pa.array(np.arange(start, end, step), dtype)])
 
 
+def nulls_like(n: int, native: ArrowAny) -> ArrayAny:
+    """Create a strongly-typed Array instance with all elements null.
+
+    Uses the type of `native`.
+    """
+    return pa.nulls(n, native.type)  # type: ignore[no-any-return]
+
+
 def lit(value: Any, dtype: DataType | None = None) -> NativeScalar:
     return pa.scalar(value) if dtype is None else pa.scalar(value, dtype)
 
 
+@overload
+def array(data: ArrowAny, /) -> ArrayAny: ...
+@overload
+def array(
+    data: Iterable[PythonLiteral], dtype: DataType | None = None, /
+) -> ArrayAny: ...
 def array(
-    value: NativeScalar | Iterable[Any], dtype: DataType | None = None, /
+    data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, /
 ) -> ArrayAny:
-    return (
-        pa.array([value], value.type)
-        if isinstance(value, pa.Scalar)
-        else pa.array(value, dtype)
-    )
+    """Convert `data` into an Array instance.
+
+    Note:
+        `dtype` is not used for existing `pyarrow` data; use `cast` instead.
+ """ + if isinstance(data, pa.ChunkedArray): + return data.combine_chunks() + if isinstance(data, pa.Array): + return data + if isinstance(data, pa.Scalar): + return pa.array([data], data.type) + return pa.array(data, dtype) def chunked_array( - arr: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / + data: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / ) -> ChunkedArrayAny: - return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) + return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) def concat_vertical_chunked( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index df57f781c1..776f4d3d60 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -16,7 +16,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -52,6 +52,7 @@ ir.boolean.All: "hash_all", ir.boolean.Any: "hash_any", ir.functions.Unique: "hash_distinct", # `hash_aggregate` only + ir.functions.NullCount: "hash_count", } REQUIRES_PYARROW_20: tuple[Literal["kurtosis"], Literal["skew"]] = ("kurtosis", "skew") @@ -138,14 +139,6 @@ def group_by_error( return InvalidOperationError(msg) -def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: - dtype = fn.string_type(native.schema.types) - it = fn.cast_table(native, dtype).itercolumns() - concat: Incomplete = pc.binary_join_element_wise - join = options.join_replace_nulls() - return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] - - class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] @@ -157,15 +150,12 @@ def compliant(self) -> Frame: return self._df def __iter__(self) -> Iterator[tuple[Any, Frame]]: - temp_name = temp.column_name(self.compliant) - native = self.compliant.native - composite_values = concat_str(acero.select_names_table(native, self.key_names)) - re_keyed = native.add_column(0, temp_name, composite_values) + by = self.key_names from_native = self.compliant._with_native - for v in composite_values.unique(): - t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) + for partition in partition_by(self.compliant.native, by): + t = from_native(partition) yield ( - t.select_names(*self.key_names).row(0), + t.select_names(*by).row(0), t.select_names(*self._column_names_original), ) @@ -178,3 +168,59 @@ def agg(self, irs: Seq[NamedIR]) -> Frame: if original := self._key_names_original: return result.rename(dict(zip(key_names, original))) return result + + +def _composite_key(native: pa.Table, *, separator: str = "") -> ChunkedArray: + """Horizontally join columns to *seed* a unique key per row combination.""" + dtype = fn.string_type(native.schema.types) + it = fn.cast_table(native, dtype).itercolumns() + concat: Incomplete = pc.binary_join_element_wise + join = options.join_replace_nulls() + return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] + + +def partition_by( + native: pa.Table, by: Sequence[str], *, include_key: bool = True +) -> Iterator[pa.Table]: + if len(by) == 1: + yield from _partition_by_one(native, by[0], include_key=include_key) + else: + yield from _partition_by_many(native, by, include_key=include_key) + + +def _partition_by_one( + native: pa.Table, by: str, *, include_key: bool = True +) -> 
Iterator[pa.Table]:
+    """Optimized path for single-column partition."""
+    arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
+    indices: pa.Int32Array = arr_dict.indices
+    if not include_key:
+        native = native.remove_column(native.schema.get_field_index(by))
+    for idx in range(len(arr_dict.dictionary)):
+        # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`
+        # Is there any reasonable way to do this in Acero?
+        yield native.filter(pc.equal(pa.scalar(idx), indices))
+
+
+def _partition_by_many(
+    native: pa.Table, by: Sequence[str], *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    original_names = native.column_names
+    temp_name = temp.column_name(original_names)
+    key = acero.col(temp_name)
+    composite_values = _composite_key(acero.select_names_table(native, by))
+    # Need to iterate over the whole thing, so converting to a Python list first should be faster
+    unique_py = composite_values.unique().to_pylist()
+    re_keyed = native.add_column(0, temp_name, composite_values)
+    source = acero.table_source(re_keyed)
+    if include_key:
+        keep = original_names
+    else:
+        ignore = {*by, temp_name}
+        keep = [name for name in original_names if name not in ignore]
+    select = acero.select_names(keep)
+    for v in unique_py:
+        # NOTE: May want to split the `Declaration` production iterator into its own function
+        # E.g., to push down column selection to *before* collection
+        # Not needed for this task though
+        yield acero.collect(source, acero.filter(key == v), select)
diff --git a/narwhals/_plan/arrow/options.py b/narwhals/_plan/arrow/options.py
index 8998b288a2..d0257c8c41 100644
--- a/narwhals/_plan/arrow/options.py
+++ b/narwhals/_plan/arrow/options.py
@@ -80,11 +80,12 @@ def _generate_agg() -> Mapping[type[agg.AggExpr], acero.AggregateOptions]:
 
 
 def _generate_function() -> Mapping[type[ir.Function], acero.AggregateOptions]:
-    from narwhals._plan.expressions import boolean
+    from narwhals._plan.expressions import boolean, functions
 
     return {
         boolean.All: scalar_aggregate(ignore_nulls=True),
         boolean.Any: scalar_aggregate(ignore_nulls=True),
+        functions.NullCount: count("only_null"),
     }
 
 
diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py
index 63333a49d4..10d2a82144 100644
--- a/narwhals/_plan/arrow/typing.py
+++ b/narwhals/_plan/arrow/typing.py
@@ -70,10 +70,18 @@ def __call__(
     def __call__(
         self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any
     ) -> ChunkedOrScalar[ScalarRT_co]: ...
+    @overload
+    def __call__(
+        self, data: Array[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> Array[ScalarRT_co]: ...
+    @overload
+    def __call__(
+        self, data: ChunkedOrArray[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> ChunkedOrArray[ScalarRT_co]: ...
     def __call__(
-        self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any
-    ) -> ChunkedOrScalar[ScalarRT_co]: ...
+        self, data: Arrow[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> Arrow[ScalarRT_co]: ...
class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]): @@ -130,6 +138,8 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot ChunkedArrayAny: TypeAlias = "ChunkedArray[Any]" ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]" ChunkedOrArrayAny: TypeAlias = "ChunkedOrArray[ScalarAny]" +ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny) +Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]" ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny" NativeScalar: TypeAlias = ScalarAny BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny] diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8ac2084034..b2e3c415b0 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -22,16 +22,22 @@ from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.series import Series - from narwhals._plan.typing import DTypeT, NonNestedDTypeT, OneOrIterable, Seq + from narwhals._plan.typing import ( + ColumnNameOrSelector, + DTypeT, + NonNestedDTypeT, + OneOrIterable, + Seq, + ) from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral T = TypeVar("T") -if sys.version_info >= (3, 13): +if sys.version_info >= (3, 13): # pragma: no cover from copy import replace as replace # noqa: PLC0414 -else: +else: # pragma: no cover def replace(obj: T, /, **changes: Any) -> T: cls = obj.__class__ @@ -85,8 +91,12 @@ def flatten_hash_safe( iterable: Iterable[OneOrIterable[CompliantSeries]], / ) -> Iterator[CompliantSeries]: ... @overload +def flatten_hash_safe( + iterable: Iterable[OneOrIterable[ColumnNameOrSelector]], / +) -> Iterator[ColumnNameOrSelector]: ... +@overload def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: ... -def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: +def flatten_hash_safe(iterable: Iterable[OneOrIterable[Any]], /) -> Iterator[Any]: """Fully unwrap all levels of nesting. Aiming to reduce the chances of passing an unhashable argument. 
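For intuition, a sketch of the intended behaviour (assuming `is_iterable_reject` treats strings and other atoms as non-iterable, so they are yielded whole):

    >>> list(flatten_hash_safe(["a", ["b", ("c", "d")]]))
    ['a', 'b', 'c', 'd']
    >>> list(flatten_hash_safe([range(2), [3]]))
    [0, 1, 3]
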
@@ -95,23 +105,23 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: if isinstance(element, Iterable) and not is_iterable_reject(element): yield from flatten_hash_safe(element) else: - yield element # type: ignore[misc] + yield element -def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: +def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: # pragma: no cover msg = f"Expected one or an iterable of strings, but got: {qualified_type_name(obj)!r}\n{obj!r}" return TypeError(msg) def ensure_seq_str(obj: OneOrIterable[str], /) -> Seq[str]: if not isinstance(obj, Iterable): - raise _not_one_or_iterable_str_error(obj) + raise _not_one_or_iterable_str_error(obj) # pragma: no cover return (obj,) if isinstance(obj, str) else tuple(obj) def ensure_list_str(obj: OneOrIterable[str], /) -> list[str]: if not isinstance(obj, Iterable): - raise _not_one_or_iterable_str_error(obj) + raise _not_one_or_iterable_str_error(obj) # pragma: no cover return [obj] if isinstance(obj, str) else list(obj) @@ -124,7 +134,7 @@ def _reprlib_repr_backport() -> reprlib.Repr: # but also a useful constructor https://github.com/python/cpython/issues/94343 import reprlib - if sys.version_info >= (3, 12): + if sys.version_info >= (3, 12): # pragma: no cover return reprlib.Repr(indent=4, maxlist=10) else: # pragma: no cover # noqa: RET505 obj = reprlib.Repr() @@ -246,7 +256,7 @@ def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: available_chars = n_chars - len_prefix if available_chars < 0: visualize = "" - else: + else: # pragma: no cover (has coverage, but there's randomness in the test) okay = "✔" * available_chars bad = "✖" * (cls._MIN_RANDOM_CHARS - available_chars) visualize = f"\n Preview: '{prefix}{okay}{bad}'" diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 45728fce47..8808ede824 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -55,7 +55,7 @@ def native(self) -> NativeFrameT_co: ... def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ... @property def columns(self) -> list[str]: ... - def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ... + def drop(self, columns: Sequence[str]) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... # Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around def filter(self, predicate: NamedIR, /) -> Self: ... @@ -64,7 +64,7 @@ def rename(self, mapping: Mapping[str, str]) -> Self: ... def schema(self) -> Mapping[str, DType]: ... def select(self, irs: Seq[NamedIR]) -> Self: ... def select_names(self, *column_names: str) -> Self: ... - def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: ... + def sort(self, by: Sequence[str], options: SortMultipleOptions) -> Self: ... def with_columns(self, irs: Seq[NamedIR]) -> Self: ... @@ -129,6 +129,9 @@ def join( suffix: str = "_right", ) -> Self: ... def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ... + def partition_by( + self, by: Sequence[str], *, include_key: bool = True + ) -> list[Self]: ... def row(self, index: int) -> tuple[Any, ...]: ... @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... 
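The `partition_by` protocol method added above is what the Arrow `partition_by`/`_partition_by_one` helpers earlier in this diff implement. A rough pyarrow-only sketch of the single-key fast path, not the actual implementation (which goes through the internal `fn`/`temp`/`acero` helpers):

    import pyarrow as pa
    import pyarrow.compute as pc

    def partition_one(table: pa.Table, by: str) -> list[pa.Table]:
        # Encode the key column once: `indices[i]` is the partition id of row `i`.
        encoded = table.column(by).combine_chunks().dictionary_encode()
        # One filtered table per distinct key (nulls never match a partition id).
        return [
            table.filter(pc.equal(encoded.indices, pa.scalar(idx)))
            for idx in range(len(encoded.dictionary))
        ]
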
diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py
index e8d0dffbed..8e1b8e0e88 100644
--- a/narwhals/_plan/compliant/expr.py
+++ b/narwhals/_plan/compliant/expr.py
@@ -92,6 +92,9 @@ def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Sel
     def rolling_expr(
         self, node: ir.RollingExpr, frame: FrameT_contra, name: str
     ) -> Self: ...
+    def shift(
+        self, node: FunctionExpr[F.Shift], frame: FrameT_contra, name: str
+    ) -> Self: ...
     def ternary_expr(
         self, node: ir.TernaryExpr, frame: FrameT_contra, name: str
     ) -> Self: ...
@@ -99,6 +102,24 @@ def ternary_expr(
     def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ...
     def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ...
     def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ...
+    def diff(
+        self, node: FunctionExpr[F.Diff], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_count(
+        self, node: FunctionExpr[F.CumCount], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_min(
+        self, node: FunctionExpr[F.CumMin], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_max(
+        self, node: FunctionExpr[F.CumMax], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_prod(
+        self, node: FunctionExpr[F.CumProd], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_sum(
+        self, node: FunctionExpr[F.CumSum], frame: FrameT_contra, name: str
+    ) -> Self: ...
     # series -> scalar
     def all(
         self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str
@@ -139,6 +160,9 @@ def min(
     def n_unique(
         self, node: agg.NUnique, frame: FrameT_contra, name: str
     ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
+    def null_count(
+        self, node: FunctionExpr[F.NullCount], frame: FrameT_contra, name: str
+    ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
     def quantile(
         self, node: agg.Quantile, frame: FrameT_contra, name: str
     ) -> CompliantScalar[FrameT_contra, SeriesT_co]: ...
diff --git a/narwhals/_plan/compliant/group_by.py b/narwhals/_plan/compliant/group_by.py
index 8e05144393..adac2bb402 100644
--- a/narwhals/_plan/compliant/group_by.py
+++ b/narwhals/_plan/compliant/group_by.py
@@ -12,7 +12,7 @@
     FrameT_co,
     ResolverT_co,
 )
-from narwhals.exceptions import ComputeError
+from narwhals._plan.exceptions import group_by_no_keys_error
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -51,8 +51,7 @@ def keys(self) -> Seq[NamedIR]:
     def key_names(self) -> Seq[str]:
         if names := self._key_names:
             return names
-        msg = "at least one key is required in a group_by operation"
-        raise ComputeError(msg)
+        raise group_by_no_keys_error()
 
 
 class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]):
@@ -163,8 +162,7 @@ def key_names(self) -> Seq[str]:
             return names
         if keys := self.keys:
             return tuple(e.name for e in keys)
-        msg = "at least one key is required in a group_by operation"
-        raise ComputeError(msg)
+        raise group_by_no_keys_error()
 
     def requires_projection(self, *, allow_aliasing: bool = False) -> bool:
         """Return True if group keys contain anything that is not a column selection.
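For example (assumed semantics, mirroring the docstring above):

    # keys = [col("a"), col("b")]   -> False: plain column selections
    # keys = [col("a").alias("k")]  -> False only when allow_aliasing=True
    # keys = [col("a") * 2]         -> True: must be projected before grouping
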
@@ -203,3 +201,13 @@ class Grouped(Grouper[Resolved]): @property def _resolver(self) -> type[Resolved]: return Resolved + + @classmethod + def by_irs(cls, *by: ExprIR) -> Self: + obj = cls.__new__(cls) + obj._keys = by + return obj + + def agg_irs(self, *aggs: ExprIR) -> Self: + self._aggs = aggs + return self diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index 25c07d7de7..bdab2e03e6 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -11,7 +11,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, aggregation as agg from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct - from narwhals._plan.expressions.functions import EwmMean + from narwhals._plan.expressions.functions import EwmMean, NullCount, Shift from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -98,9 +98,20 @@ def min(self, node: agg.Min, frame: FrameT_contra, name: str) -> Self: def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self: return self.from_python(1, name, dtype=None, version=self.version) + def null_count( + self, node: FunctionExpr[NullCount], frame: FrameT_contra, name: str + ) -> Self: + """Returns 1 if null, else 0.""" + ... + def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self: return self._cast_float(node.expr, frame, name) + def shift(self, node: FunctionExpr[Shift], frame: FrameT_contra, name: str) -> Self: + if node.function.n == 0: + return self._with_evaluated(self._evaluated, name) + return self.from_python(None, name, dtype=None, version=self.version) + def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/narwhals/_plan/contexts.py b/narwhals/_plan/contexts.py deleted file mode 100644 index 773b699df9..0000000000 --- a/narwhals/_plan/contexts.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -import enum - -__all__ = ["ExprContext"] - - -class ExprContext(enum.Enum): - """A [context] to evaluate expressions in. - - [context]: https://docs.pola.rs/user-guide/concepts/expressions-and-contexts/#contexts - """ - - SELECT = "select" - """The output schema has the same order and length as the (expanded) input expressions. - - That order is determined during expansion of selectors in an earlier step. - """ - - WITH_COLUMNS = "with_columns" - """The output schema *derives from* the input schema, but *may* produce a different shape. 
- - - Expressions producing **new names** are appended to the end of the schema - - Expressions producing **existing names** will replace the existing column positionally - """ - - def is_select(self) -> bool: - return self is ExprContext.SELECT - - def is_with_columns(self) -> bool: - return self is ExprContext.WITH_COLUMNS diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 625f1990e0..ab6a9d85cc 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -3,8 +3,9 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, get_args, overload from narwhals._plan import _parse -from narwhals._plan._expansion import prepare_projection +from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection from narwhals._plan.common import ensure_seq_str, temp +from narwhals._plan.exceptions import group_by_no_keys_error from narwhals._plan.group_by import GroupBy, Grouped from narwhals._plan.options import SortMultipleOptions from narwhals._plan.series import Series @@ -49,7 +50,7 @@ def version(self) -> Version: return self._version @property - def implementation(self) -> Implementation: + def implementation(self) -> Implementation: # pragma: no cover return self._compliant.implementation @property @@ -60,7 +61,7 @@ def schema(self) -> Schema: def columns(self) -> list[str]: return self._compliant.columns - def __repr__(self) -> str: # pragma: no cover + def __repr__(self) -> str: return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: @@ -69,12 +70,12 @@ def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: return type(self)(compliant) - def to_native(self) -> NativeFrameT_co: + def to_native(self) -> NativeFrameT_co: # pragma: no cover return self._compliant.native def filter( self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any - ) -> Self: + ) -> Self: # pragma: no cover e = _parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) named_irs, _ = prepare_projection((e,), schema=self) if len(named_irs) != 1: @@ -104,21 +105,32 @@ def sort( descending: OneOrIterable[bool] = False, nulls_last: OneOrIterable[bool] = False, ) -> Self: - sort = _parse.parse_sort_by_into_seq_of_expr_ir(by, *more_by) + s_irs = _parse.parse_into_seq_of_selector_ir(by, *more_by) + names = expand_selector_irs_names(s_irs, schema=self) opts = SortMultipleOptions.parse(descending=descending, nulls_last=nulls_last) - named_irs, _ = prepare_projection(sort, schema=self) - return self._with_compliant(self._compliant.sort(named_irs, opts)) + return self._with_compliant(self._compliant.sort(names, opts)) - def drop(self, *columns: str, strict: bool = True) -> Self: - return self._with_compliant(self._compliant.drop(columns, strict=strict)) + def drop( + self, *columns: OneOrIterable[ColumnNameOrSelector], strict: bool = True + ) -> Self: + s_ir = _parse.parse_into_combined_selector_ir(*columns, require_all=strict) + names = expand_selector_irs_names((s_ir,), schema=self) + return self._with_compliant(self._compliant.drop(names)) - def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: - subset = [subset] if isinstance(subset, str) else subset + def drop_nulls( + self, subset: OneOrIterable[ColumnNameOrSelector] | None = None + ) -> Self: + if subset is not None: + s_irs = 
_parse.parse_into_seq_of_selector_ir(subset) + subset = expand_selector_irs_names(s_irs, schema=self) return self._with_compliant(self._compliant.drop_nulls(subset)) def rename(self, mapping: Mapping[str, str]) -> Self: return self._with_compliant(self._compliant.rename(mapping)) + def collect_schema(self) -> Schema: + return self.schema + class DataFrame( BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT] @@ -129,7 +141,7 @@ class DataFrame( def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation - def __len__(self) -> int: + def __len__(self) -> int: # pragma: no cover return len(self._compliant) @property @@ -182,17 +194,17 @@ def to_dict( def to_dict( self, *, as_series: bool = True ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: - if as_series: + if as_series: # pragma: no cover return { key: self._series(value) for key, value in self._compliant.to_dict(as_series=as_series).items() } return self._compliant.to_dict(as_series=as_series) - def to_series(self, index: int = 0) -> Series[NativeSeriesT]: + def to_series(self, index: int = 0) -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.to_series(index)) - def get_column(self, name: str) -> Series[NativeSeriesT]: + def get_column(self, name: str) -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.get_column(name)) @overload @@ -252,12 +264,25 @@ def filter( **constraints, ) named_irs, _ = prepare_projection((e,), schema=self) - if len(named_irs) != 1: + if len(named_irs) != 1: # pragma: no cover # Should be unreachable, but I guess we will see msg = f"Expected a single predicate after expansion, but got {len(named_irs)!r}\n\n{named_irs!r}" raise ValueError(msg) return self._with_compliant(self._compliant.filter(named_irs[0])) + def partition_by( + self, + by: OneOrIterable[ColumnNameOrSelector], + *more_by: ColumnNameOrSelector, + include_key: bool = True, + ) -> list[Self]: + by_selectors = _parse.parse_into_seq_of_selector_ir(by, *more_by) + names = expand_selector_irs_names(by_selectors, schema=self) + if not names: + raise group_by_no_keys_error() + partitions = self._compliant.partition_by(names, include_key=include_key) + return [self._with_compliant(p) for p in partitions] + def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]: return obj in {"inner", "left", "full", "cross", "anti", "semi"} diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index cfeb87644b..75ecd02a87 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -6,6 +6,7 @@ from itertools import groupby from typing import TYPE_CHECKING +from narwhals._utils import qualified_type_name from narwhals.exceptions import ( ColumnNotFoundError, ComputeError, @@ -18,12 +19,9 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Collection, Iterable from typing import Any - import pandas as pd - import polars as pl - from narwhals._plan import expressions as ir from narwhals._plan._function import Function from narwhals._plan.expressions.operators import Operator @@ -54,35 +52,41 @@ def hist_bins_monotonic_error(bins: Seq[float]) -> ComputeError: # noqa: ARG001 return ComputeError(msg) -# NOTE: Always underlining `right`, since the message refers to both types of exprs -# Assuming the most recent as the issue +def _binary_underline( + left: ir.ExprIR, + operator: Operator, + right: ir.ExprIR, + /, + *, + underline_right: bool = True, +) -> str: + lhs, op, rhs = 
repr(left), repr(operator), repr(right) + if underline_right: + indent = (len(lhs) + len(op) + 2) * " " + underline = len(rhs) * "^" + else: + indent = "" + underline = len(lhs) * "^" + return f"{lhs} {op} {rhs}\n{indent}{underline}" + + def binary_expr_shape_error( left: ir.ExprIR, op: Operator, right: ir.ExprIR ) -> ShapeError: - lhs_op = f"{left!r} {op!r} " - rhs = repr(right) - indent = len(lhs_op) * " " - underline = len(rhs) * "^" + expr = _binary_underline(left, op, right, underline_right=True) msg = ( - f"Cannot combine length-changing expressions with length-preserving ones.\n" - f"{lhs_op}{rhs}\n{indent}{underline}" + f"Cannot combine length-changing expressions with length-preserving ones.\n{expr}" ) return ShapeError(msg) -# TODO @dangotbanned: Share the right underline code w/ `binary_expr_shape_error` def binary_expr_multi_output_error( - left: ir.ExprIR, op: Operator, right: ir.ExprIR + origin: ir.BinaryExpr, left_expand: Seq[ir.ExprIR], right_expand: Seq[ir.ExprIR] ) -> MultiOutputExpressionError: - lhs_op = f"{left!r} {op!r} " - rhs = repr(right) - indent = len(lhs_op) * " " - underline = len(rhs) * "^" - msg = ( - "Multi-output expressions are only supported on the " - f"left-hand side of a binary operation.\n" - f"{lhs_op}{rhs}\n{indent}{underline}" - ) + len_left, len_right = len(left_expand), len(right_expand) + lhs, op, rhs = origin.left, origin.op, origin.right + expr = _binary_underline(lhs, op, rhs, underline_right=len_left < len_right) + msg = f"Cannot combine selectors that produce a different number of columns ({len_left} != {len_right}).\n{expr}" return MultiOutputExpressionError(msg) @@ -135,6 +139,24 @@ def over_row_separable_error( return InvalidOperationError(msg) +def over_order_by_names_error( + expr: ir.OrderedWindowExpr, by: ir.ExprIR +) -> InvalidOperationError: + if by.meta.is_column_selection(allow_aliasing=True): + # narwhals dev error + msg = ( + f"Cannot use `{type(expr).__name__}.order_by_names()` before expression expansion.\n" + f"Found unresolved selection {by!r}, in:\n{expr!r}" + ) + else: + # user error + msg = ( + f"Only column selection expressions are supported in `over(order_by=...)`.\n" + f"Found {by!r}, in:\n{expr!r}" + ) + return InvalidOperationError(msg) + + def invalid_into_expr_error( first_input: Iterable[IntoExpr], more_inputs: tuple[Any, ...], @@ -150,19 +172,9 @@ def invalid_into_expr_error( return InvalidIntoExprError(msg) -def is_iterable_pandas_error(obj: pd.DataFrame | pd.Series[Any], /) -> TypeError: +def is_iterable_error(obj: object, /) -> TypeError: msg = ( - f"Expected Narwhals class or scalar, got: {type(obj)}. " - "Perhaps you forgot a `nw.from_native` somewhere?" - ) - return TypeError(msg) - - -def is_iterable_polars_error( - obj: pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame, / -) -> TypeError: - msg = ( - f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n" + f"Expected Narwhals class or scalar, got: {qualified_type_name(obj)!r}.\n\n" "Hint: Perhaps you\n" "- forgot a `nw.from_native` somewhere?\n" "- used `pl.col` instead of `nw.col`?" 
@@ -170,9 +182,10 @@ def is_iterable_polars_error( return TypeError(msg) -def duplicate_error(exprs: Seq[ir.ExprIR]) -> DuplicateError: +def duplicate_error(exprs: Collection[ir.ExprIR]) -> DuplicateError: INDENT = "\n " # noqa: N806 names = [_output_name(expr) for expr in exprs] + exprs = sorted(exprs, key=_output_name) duplicates = {k for k, v in Counter(names).items() if v > 1} group_by_name = groupby(exprs, _output_name) name_exprs = { @@ -204,9 +217,14 @@ def column_not_found_error( def column_index_error( index: int, schema_or_column_names: Iterable[str], / -) -> ComputeError: +) -> ColumnNotFoundError: # NOTE: If the original expression used a negative index, we should use that as well n_names = len(tuple(schema_or_column_names)) max_nth = f"`nth({n_names - 1})`" if index >= 0 else f"`nth(-{n_names})`" msg = f"Invalid column index {index!r}\nHint: The schema's last column is {max_nth}" + return ColumnNotFoundError(msg) + + +def group_by_no_keys_error() -> ComputeError: + msg = "at least one key is required in a group_by operation" return ComputeError(msg) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index f1dea8ac80..5dfa9e3770 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -2,10 +2,10 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, overload +from typing import TYPE_CHECKING, Any, ClassVar from narwhals._plan import common, expressions as ir -from narwhals._plan._guards import is_column, is_expr, is_series +from narwhals._plan._guards import is_expr, is_series from narwhals._plan._parse import ( parse_into_expr_ir, parse_into_seq_of_expr_ir, @@ -17,7 +17,6 @@ functions as F, operators as ops, ) -from narwhals._plan.expressions.selectors import by_name from narwhals._plan.options import ( EWMOptions, RankOptions, @@ -29,7 +28,7 @@ from narwhals.exceptions import ComputeError if TYPE_CHECKING: - from typing_extensions import Never, Self + from typing_extensions import Self from narwhals._plan._function import Function from narwhals._plan.expressions.categorical import ExprCatNamespace @@ -51,8 +50,6 @@ ) -# NOTE: Overly simplified placeholders for mocking typing -# Entirely ignoring namespace + function binding class Expr: _ir: ir.ExprIR _version: ClassVar[Version] = Version.MAIN @@ -83,8 +80,10 @@ def alias(self, name: str) -> Self: def cast(self, dtype: IntoDType) -> Self: return self._from_ir(self._ir.cast(common.into_dtype(dtype))) - def exclude(self, *names: OneOrIterable[str]) -> Self: - return self._from_ir(ir.Exclude.from_names(self._ir, *names)) + def exclude(self, *names: OneOrIterable[str]) -> Expr: + from narwhals._plan import selectors as cs + + return (self.meta.as_selector() - cs.by_name(*names)).as_expr() def count(self) -> Self: return self._from_ir(agg.Count(expr=self._ir)) @@ -251,16 +250,16 @@ def clip( it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) - def cum_count(self, *, reverse: bool = False) -> Self: + def cum_count(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumCount(reverse=reverse)) - def cum_min(self, *, reverse: bool = False) -> Self: + def cum_min(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumMin(reverse=reverse)) - def cum_max(self, *, reverse: bool = False) -> Self: + def cum_max(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumMax(reverse=reverse)) 
- def cum_prod(self, *, reverse: bool = False) -> Self: + def cum_prod(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumProd(reverse=reverse)) def cum_sum(self, *, reverse: bool = False) -> Self: @@ -268,7 +267,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center) return self._with_unary(F.RollingSum(options=options)) @@ -285,7 +284,7 @@ def rolling_var( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingVar(options=options)) @@ -296,7 +295,7 @@ def rolling_std( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingStd(options=options)) @@ -536,14 +535,14 @@ def name(self) -> ExprNameNamespace: >>> >>> renamed = nw.col("a", "b").name.suffix("_changed") >>> str(renamed._ir) - "RenameAlias(expr=Columns(names=[a, b]), function=Suffix(suffix='_changed'))" + "RenameAlias(expr=RootSelector(selector=ByName(names=[a, b], require_all=True)), function=Suffix(suffix='_changed'))" """ from narwhals._plan.expressions.name import ExprNameNamespace return ExprNameNamespace(_expr=self) @property - def cat(self) -> ExprCatNamespace: + def cat(self) -> ExprCatNamespace: # pragma: no cover from narwhals._plan.expressions.categorical import ExprCatNamespace return ExprCatNamespace(_expr=self) @@ -561,7 +560,7 @@ def dt(self) -> ExprDateTimeNamespace: return ExprDateTimeNamespace(_expr=self) @property - def list(self) -> ExprListNamespace: + def list(self) -> ExprListNamespace: # pragma: no cover from narwhals._plan.expressions.lists import ExprListNamespace return ExprListNamespace(_expr=self) @@ -573,111 +572,5 @@ def str(self) -> ExprStringNamespace: return ExprStringNamespace(_expr=self) -class Selector(Expr): - _ir: ir.SelectorIR - - def __repr__(self) -> str: - return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" - - @classmethod - def _from_ir(cls, selector_ir: ir.SelectorIR, /) -> Self: # type: ignore[override] - obj = cls.__new__(cls) - obj._ir = selector_ir - return obj - - def _to_expr(self) -> Expr: - return self._ir.to_narwhals(self.version) - - @overload # type: ignore[override] - def __or__(self, other: Self) -> Self: ... - @overload - def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.Or() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() | other - - @overload # type: ignore[override] - def __and__(self, other: Self) -> Self: ... - @overload - def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - other = by_name(name) - if isinstance(other, type(self)): - op = ops.And() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() & other - - @overload # type: ignore[override] - def __sub__(self, other: Self) -> Self: ... 
- @overload - def __sub__(self, other: IntoExpr) -> Expr: ... - def __sub__(self, other: IntoExpr) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.Sub() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() - other - - @overload # type: ignore[override] - def __xor__(self, other: Self) -> Self: ... - @overload - def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if isinstance(other, type(self)): - op = ops.ExclusiveOr() - return self._from_ir(op.to_binary_selector(self._ir, other._ir)) - return self._to_expr() ^ other - - def __invert__(self) -> Self: - return self._from_ir(ir.InvertSelector(selector=self._ir)) - - def __add__(self, other: Any) -> Expr: # type: ignore[override] - if isinstance(other, type(self)): - msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" - raise TypeError(msg) - return self._to_expr() + other # type: ignore[no-any-return] - - def __radd__(self, other: Any) -> Never: - msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" - raise TypeError(msg) - - def __rsub__(self, other: Any) -> Never: - msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" - raise TypeError(msg) - - @overload # type: ignore[override] - def __rand__(self, other: Self) -> Self: ... - @overload - def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) & self - return self._to_expr().__rand__(other) - - @overload # type: ignore[override] - def __ror__(self, other: Self) -> Self: ... - @overload - def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) | self - return self._to_expr().__ror__(other) - - @overload # type: ignore[override] - def __rxor__(self, other: Self) -> Self: ... - @overload - def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... - def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: - if is_column(other) and (name := other.meta.output_name()): - return by_name(name) ^ self - return self._to_expr().__rxor__(other) - - class ExprV1(Expr): _version: ClassVar[Version] = Version.V1 - - -class SelectorV1(Selector): - _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expressions/__init__.py b/narwhals/_plan/expressions/__init__.py index 1bdae0224f..0d97c20288 100644 --- a/narwhals/_plan/expressions/__init__.py +++ b/narwhals/_plan/expressions/__init__.py @@ -22,21 +22,16 @@ from narwhals._plan.expressions.aggregation import AggExpr, OrderableAggExpr, max, min from narwhals._plan.expressions.expr import ( Alias, - All, AnonymousExpr, BinaryExpr, BinarySelector, Cast, Column, - Columns, - Exclude, Filter, FunctionExpr, - IndexColumns, InvertSelector, Len, Literal, - Nth, OrderedWindowExpr, RangeExpr, RollingExpr, @@ -45,11 +40,7 @@ SortBy, TernaryExpr, WindowExpr, - _ColumnSelection, # if needs exposing, make it public! 
col, - cols, - index_columns, - nth, ) from narwhals._plan.expressions.name import KeepName, RenameAlias from narwhals._plan.expressions.window import over, over_ordered @@ -57,25 +48,20 @@ __all__ = [ "AggExpr", "Alias", - "All", "AnonymousExpr", "BinaryExpr", "BinarySelector", "Cast", "Column", - "Columns", - "Exclude", "ExprIR", "Filter", "Function", "FunctionExpr", - "IndexColumns", "InvertSelector", "KeepName", "Len", "Literal", "NamedIR", - "Nth", "OrderableAggExpr", "OrderedWindowExpr", "RangeExpr", @@ -87,19 +73,15 @@ "SortBy", "TernaryExpr", "WindowExpr", - "_ColumnSelection", "aggregation", "boolean", "categorical", "col", - "cols", "functions", - "index_columns", "lists", "max", "min", "named_ir", - "nth", "operators", "over", "over_ordered", diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 92a563f586..ed51f4cb90 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -25,10 +25,15 @@ def __repr__(self) -> str: def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() - def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: - if expr.is_scalar: - raise agg_scalar_error(self, expr) - super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] + # NOTE: Interacting badly with `pyright` synthesizing the `__replace__` signature + if not TYPE_CHECKING: + + def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: + if expr.is_scalar: + raise agg_scalar_error(self, expr) + super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] + else: # pragma: no cover + ... # fmt: off diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index ebc2a8643b..f6042a391d 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -3,12 +3,13 @@ # NOTE: Needed to avoid naming collisions # - Any import typing as t +from typing import TYPE_CHECKING from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar -if t.TYPE_CHECKING: +if TYPE_CHECKING: from typing_extensions import Self from narwhals._plan._expr_ir import ExprIR diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 7c59fd4443..5bb7157f5d 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -20,7 +20,7 @@ class IRCatNamespace(IRNamespace): class ExprCatNamespace(ExprNamespace[IRCatNamespace]): @property def _ir_namespace(self) -> type[IRCatNamespace]: - return IRCatNamespace + return IRCatNamespace # pragma: no cover def get_categories(self) -> Expr: - return self._with_unary(self._ir.get_categories()) + return self._with_unary(self._ir.get_categories()) # pragma: no cover diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 4fa2f6cf6e..62c211500e 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -2,16 +2,20 @@ from __future__ import annotations -# NOTE: Needed to avoid naming collisions -# - Literal import typing as t +from typing import TYPE_CHECKING from narwhals._plan._expr_ir import ExprIR, SelectorIR -from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.exceptions import function_expr_invalid_operation_error +from narwhals._plan.common import replace +from narwhals._plan.exceptions import ( + 
function_expr_invalid_operation_error, + over_order_by_names_error, +) +from narwhals._plan.expressions import selectors as cs from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT_co, + Ignored, LeftSelectorT, LeftT, LiteralT, @@ -26,33 +30,31 @@ ) from narwhals.exceptions import InvalidOperationError -if t.TYPE_CHECKING: +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + from typing_extensions import Self from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expressions.functions import MapBatches # noqa: F401 from narwhals._plan.expressions.literal import LiteralValue - from narwhals._plan.expressions.selectors import Selector from narwhals._plan.expressions.window import Window from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions + from narwhals._plan.schema import FrozenSchema from narwhals.dtypes import DType + from narwhals.typing import IntoDType __all__ = [ "Alias", - "All", "AnonymousExpr", "BinaryExpr", "BinarySelector", "Cast", "Column", - "Columns", - "Exclude", "Filter", "FunctionExpr", - "IndexColumns", "Len", "Literal", - "Nth", "RollingExpr", "RootSelector", "SelectorIR", @@ -68,18 +70,6 @@ def col(name: str, /) -> Column: return Column(name=name) -def cols(*names: str) -> Columns: - return Columns(names=names) - - -def nth(index: int, /) -> Nth: - return Nth(index=index) - - -def index_columns(*indices: int) -> IndexColumns: - return IndexColumns(indices=indices) - - class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "name") expr: ExprIR @@ -100,53 +90,8 @@ class Column(ExprIR, config=ExprIROptions.namespaced("col")): def __repr__(self) -> str: return f"col({self.name!r})" - -class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): - """Nodes which can resolve to `Column`(s) with a `Schema`.""" - - -class Columns(_ColumnSelection): - __slots__ = ("names",) - names: Seq[str] - - def __repr__(self) -> str: - return f"cols({list(self.names)!r})" - - -class Nth(_ColumnSelection): - __slots__ = ("index",) - index: int - - def __repr__(self) -> str: - return f"nth({self.index})" - - -class IndexColumns(_ColumnSelection): - __slots__ = ("indices",) - indices: Seq[int] - - def __repr__(self) -> str: - return f"index_columns({self.indices!r})" - - -class All(_ColumnSelection): - def __repr__(self) -> str: - return "all()" - - -class Exclude(_ColumnSelection, child=("expr",)): - __slots__ = ("expr", "names") - expr: ExprIR - """Default is `all()`.""" - names: Seq[str] - """Excluded names.""" - - @staticmethod - def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: - return Exclude(expr=expr, names=tuple(flatten_hash_safe(names))) - - def __repr__(self) -> str: - return f"{self.expr!r}.exclude({list(self.names)!r})" + def to_selector_ir(self) -> RootSelector: + return cs.ByName.from_name(self.name).to_selector_ir() class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): @@ -299,19 +244,28 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: """ for e in self.input[:1]: yield from e.iter_output_name() - - def __init__( - self, - *, - input: Seq[ExprIR], # noqa: A002 - function: FunctionT_co, - options: FunctionOptions, - **kwds: t.Any, - ) -> None: - parent = input[0] - if parent.is_scalar and not options.is_elementwise(): - raise function_expr_invalid_operation_error(function, parent) - super().__init__(**dict(input=input, function=function, options=options, 
**kwds)) + # NOTE: Covering the empty case doesn't make sense without implementing `FunctionFlags.ALLOW_EMPTY_INPUTS` + # https://github.com/pola-rs/polars/blob/df69276daf5d195c8feb71eef82cbe9804e0f47f/crates/polars-plan/src/plans/options.rs#L106-L107 + return # pragma: no cover + + # NOTE: Interacting badly with `pyright` synthesizing the `__replace__` signature + if not TYPE_CHECKING: + + def __init__( + self, + *, + input: Seq[ExprIR], # noqa: A002 + function: FunctionT_co, + options: FunctionOptions, + **kwds: t.Any, + ) -> None: + parent = input[0] + if parent.is_scalar and not options.is_elementwise(): + raise function_expr_invalid_operation_error(function, parent) + kwargs = dict(input=input, function=function, options=options, **kwds) + super().__init__(**kwargs) + else: # pragma: no cover + ... def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str @@ -432,6 +386,19 @@ def iter_root_names(self) -> t.Iterator[ExprIR]: yield from e.iter_left() yield self + def order_by_names(self) -> Iterator[str]: + """Yield the names resolved from expanding `order_by`. + + Raises: + InvalidOperationError: If used *before* expansion, or + `order_by` contains expressions that do more than select. + """ + for by in self.order_by: + if isinstance(by, Column): + yield by.name + else: + raise over_order_by_names_error(self, by) + class Len(ExprIR, config=ExprIROptions.namespaced()): @property @@ -446,17 +413,46 @@ def __repr__(self) -> str: return "len()" +class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): + """When-Then-Otherwise.""" + + __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 + predicate: ExprIR + truthy: ExprIR + falsy: ExprIR + + @property + def is_scalar(self) -> bool: + return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar + + def __repr__(self) -> str: + return ( + f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" + ) + + def iter_output_name(self) -> t.Iterator[ExprIR]: + yield from self.truthy.iter_output_name() + + class RootSelector(SelectorIR): """A single selector expression.""" __slots__ = ("selector",) - selector: Selector + selector: cs.Selector def __repr__(self) -> str: return f"{self.selector!r}" - def matches_column(self, name: str, dtype: DType) -> bool: - return self.selector.matches_column(name, dtype) + def iter_expand_names( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + yield from self.selector.iter_expand(schema, ignored_columns) + + def matches(self, dtype: IntoDType) -> bool: + return self.selector.to_dtype_selector().matches(dtype) + + def to_dtype_selector(self) -> Self: + return replace(self, selector=self.selector.to_dtype_selector()) class BinarySelector( @@ -466,11 +462,36 @@ class BinarySelector( ): """Application of two selector exprs via a set operator.""" - def matches_column(self, name: str, dtype: DType) -> bool: - left = self.left.matches_column(name, dtype) - right = self.right.matches_column(name, dtype) + def iter_expand_names( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + # by_name, by_index (upstream) lose their ability to reorder when used as a binary op + # (As designed) https://github.com/pola-rs/polars/issues/19384 + names = schema.names + left = frozenset(self.left.iter_expand_names(schema, ignored_columns)) + right = frozenset(self.right.iter_expand_names(schema, ignored_columns)) + remaining: frozenset[str] = self.op(left, right) + target: Iterable[str] + if 
remaining: + target = ( + names + if len(remaining) == len(names) + else (nm for nm in names if nm in remaining) + ) + else: + target = () + yield from target + + def matches(self, dtype: IntoDType) -> bool: + left = self.left.matches(dtype) + right = self.right.matches(dtype) return bool(self.op(left, right)) + def to_dtype_selector(self) -> Self: + return replace( + self, left=self.left.to_dtype_selector(), right=self.right.to_dtype_selector() + ) + class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) @@ -479,26 +500,27 @@ class InvertSelector(SelectorIR, t.Generic[SelectorT]): def __repr__(self) -> str: return f"~{self.selector!r}" - def matches_column(self, name: str, dtype: DType) -> bool: - return not self.selector.matches_column(name, dtype) - - -class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): - """When-Then-Otherwise.""" - - __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 - predicate: ExprIR - truthy: ExprIR - falsy: ExprIR - - @property - def is_scalar(self) -> bool: - return self.predicate.is_scalar and self.truthy.is_scalar and self.falsy.is_scalar + def iter_expand_names( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + # by_name, by_index (upstream) lose their ability to reorder when used as a binary op + # that includes invert, which is implemented as Difference(All, Selector) + # (As designed) https://github.com/pola-rs/polars/issues/19384 + names = schema.names + ignore = frozenset(self.selector.iter_expand_names(schema, ignored_columns)) + target: Iterable[str] + if ignore: + target = ( + () + if len(ignore) == len(names) + else (nm for nm in names if nm not in ignore) + ) + else: + target = names + yield from target - def __repr__(self) -> str: - return ( - f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" - ) + def matches(self, dtype: IntoDType) -> bool: + return not self.selector.to_dtype_selector().matches(dtype) - def iter_output_name(self) -> t.Iterator[ExprIR]: - yield from self.truthy.iter_output_name() + def to_dtype_selector(self) -> Self: + return replace(self, selector=self.selector.to_dtype_selector()) diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 604e054a5e..b14090b985 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -21,7 +21,7 @@ class IRListNamespace(IRNamespace): class ExprListNamespace(ExprNamespace[IRListNamespace]): @property def _ir_namespace(self) -> type[IRListNamespace]: - return IRListNamespace + return IRListNamespace # pragma: no cover def len(self) -> Expr: - return self._with_unary(self._ir.len()) + return self._with_unary(self._ir.len()) # pragma: no cover diff --git a/narwhals/_plan/expressions/operators.py b/narwhals/_plan/expressions/operators.py index 9ecc45737a..ebf3d979b9 100644 --- a/narwhals/_plan/expressions/operators.py +++ b/narwhals/_plan/expressions/operators.py @@ -7,7 +7,6 @@ from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions import ( binary_expr_length_changing_error, - binary_expr_multi_output_error, binary_expr_shape_error, ) @@ -46,8 +45,6 @@ def to_binary_expr( ) -> BinaryExpr[LeftT, Self, RightT]: from narwhals._plan.expressions.expr import BinaryExpr - if right.meta.has_multiple_outputs(): - raise binary_expr_multi_output_error(left, self, right) if _is_filtration(left): if _is_filtration(right): raise binary_expr_length_changing_error(left, self, right) diff --git 
a/narwhals/_plan/expressions/selectors.py b/narwhals/_plan/expressions/selectors.py index d09d7af761..c741ad5632 100644 --- a/narwhals/_plan/expressions/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -1,93 +1,278 @@ -"""Deviations from `polars`. - -- A `Selector` corresponds to a `nw.selectors` function -- Binary ops are represented as a `BinarySelector`, similar to `BinaryExpr`. -""" - from __future__ import annotations +import functools import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ClassVar, final from narwhals._plan._immutable import Immutable from narwhals._plan.common import flatten_hash_safe -from narwhals._utils import Version, _parse_time_unit_and_time_zone +from narwhals._plan.exceptions import column_index_error, column_not_found_error +from narwhals._utils import ( + Version, + _parse_time_unit_and_time_zone, + isinstance_or_issubclass, +) +from narwhals.dtypes import DType, FloatType, IntegerType, NumericType, TemporalType +from narwhals.typing import IntoDType, TimeUnit if TYPE_CHECKING: + from collections.abc import Iterator from datetime import timezone - from typing import TypeVar - from narwhals._plan import expr + import narwhals.dtypes as nw_dtypes + from narwhals._plan.expressions import SelectorIR from narwhals._plan.expressions.expr import RootSelector - from narwhals._plan.typing import OneOrIterable - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit + from narwhals._plan.schema import FrozenSchema + from narwhals._plan.typing import Ignored, OneOrIterable, Seq + - T = TypeVar("T") +_dtypes = Version.MAIN.dtypes -dtypes = Version.MAIN.dtypes +_ALL_TIME_UNITS = frozenset[TimeUnit](("ms", "us", "ns", "s")) class Selector(Immutable): - def to_selector(self) -> RootSelector: + def __repr__(self) -> str: + return f"ncs.{type(self).__name__.lower()}()" + + def to_selector_ir(self) -> RootSelector: from narwhals._plan.expressions.expr import RootSelector return RootSelector(selector=self) - def matches_column(self, name: str, dtype: DType) -> bool: - raise NotImplementedError(type(self)) + def to_dtype_selector(self) -> DTypeSelector: + msg = f"expected datatype based expression got {self!r}" + raise TypeError(msg) + def iter_expand( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + """Yield column names that match the selector, in `schema` order[^1]. -class All(Selector): + Adapted from [upstream]. + + Arguments: + schema: Target scope to expand the selector in. + ignored_columns: Names of `group_by` columns, which are excluded[^2] from the result. + + Note: + [^1]: `ByName`, `ByIndex` return their inputs in given order not in schema order. + + Note: + [^2]: `ByName`, `ByIndex` will never be ignored. + + [upstream]: https://github.com/pola-rs/polars/blob/2b241543851800595efd343be016b65cdbdd3c9f/crates/polars-plan/src/dsl/selector.rs#L188-L198 + """ + msg = f"{type(self).__name__}.iter_expand" + raise NotImplementedError(msg) + + +class DTypeSelector(Selector): + # https://github.com/pola-rs/polars/blob/2b241543851800595efd343be016b65cdbdd3c9f/crates/polars-plan/src/dsl/selector.rs#L110-L172 + _dtype: ClassVar[type[DType]] + + def __init_subclass__(cls, *args: Any, dtype: type[DType], **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + cls._dtype = dtype + + def to_dtype_selector(self) -> DTypeSelector: + return self + + @final + def matches(self, dtype: IntoDType) -> bool: + """Return True if we can select this dtype. 
+
+        Important:
+            The result will *only* be cached if this method is **not overridden**.
+            Instead, use `DTypeSelector._matches` to customize the check.
+        """
+        return _selector_matches(self, dtype)
+
+    def _matches(self, dtype: IntoDType) -> bool:
+        """Implementation of `DTypeSelector.matches`."""
+        return isinstance_or_issubclass(dtype, self._dtype)
+
+    def iter_expand(
+        self, schema: FrozenSchema, ignored_columns: Ignored
+    ) -> Iterator[str]:
+        if ignored_columns:
+            for name, dtype in schema.items():
+                if self.matches(dtype) and name not in ignored_columns:
+                    yield name
+        else:
+            yield from (name for name, dtype in schema.items() if self.matches(dtype))
+
+
+class DTypeAll(DTypeSelector, dtype=DType):
     def __repr__(self) -> str:
         return "ncs.all()"
 
-    def matches_column(self, name: str, dtype: DType) -> bool:
+    def _matches(self, dtype: IntoDType) -> bool:
         return True
 
 
-class ByDType(Selector):
-    __slots__ = ("dtypes",)
-    dtypes: frozenset[DType | type[DType]]
+class All(Selector):
+    def to_dtype_selector(self) -> DTypeSelector:
+        return DTypeAll()
+
+    def iter_expand(
+        self, schema: FrozenSchema, ignored_columns: Ignored
+    ) -> Iterator[str]:
+        if ignored_columns:
+            yield from (name for name in schema if name not in ignored_columns)
+        else:
+            yield from schema
 
-    @staticmethod
-    def from_dtypes(*dtypes: OneOrIterable[DType | type[DType]]) -> ByDType:
-        return ByDType(dtypes=frozenset(flatten_hash_safe(dtypes)))
+
+class ByIndex(Selector):
+    __slots__ = ("indices", "require_all")
+    indices: Seq[int]
+    require_all: bool
 
     def __repr__(self) -> str:
-        els = ", ".join(
-            tp.__name__ if isinstance(tp, type) else repr(tp) for tp in self.dtypes
+        if len(self.indices) == 1 and self.indices[0] in {0, -1}:
+            name = "first" if self.indices[0] == 0 else "last"
+            return f"ncs.{name}()"
+        return f"ncs.by_index({list(self.indices)}, require_all={self.require_all})"
+
+    @staticmethod
+    def _iter_validate(indices: tuple[OneOrIterable[int], ...], /) -> Iterator[int]:
+        for idx in flatten_hash_safe(indices):
+            if not isinstance(idx, int):
+                msg = f"invalid index value: {idx!r}"
+                raise TypeError(msg)
+            yield idx
+
+    @staticmethod
+    def from_indices(*indices: OneOrIterable[int], require_all: bool = True) -> ByIndex:
+        return ByIndex(
+            indices=tuple(ByIndex._iter_validate(indices)), require_all=require_all
         )
-        return f"ncs.by_dtype(dtypes=[{els}])"
 
-    def matches_column(self, name: str, dtype: DType) -> bool:
-        return dtype in self.dtypes
+    @staticmethod
+    def from_index(index: int, /, *, require_all: bool = True) -> ByIndex:
+        return ByIndex(indices=(index,), require_all=require_all)
+
+    def iter_expand(
+        self, schema: FrozenSchema, ignored_columns: Ignored
+    ) -> Iterator[str]:
+        names = schema.names
+        n_fields = len(names)
+        # NOTE: `-n_fields <= idx < n_fields` is the full range of valid indices;
+        # `abs(idx) < n_fields` would wrongly reject `idx == -n_fields`
+        if not self.require_all:
+            if n_fields == 0:
+                yield from ()
+            else:
+                yield from (
+                    names[idx] for idx in self.indices if -n_fields <= idx < n_fields
+                )
+        else:
+            for idx in self.indices:
+                if -n_fields <= idx < n_fields:
+                    yield names[idx]
+                else:
+                    raise column_index_error(idx, schema)
+
+
+class ByName(Selector):
+    __slots__ = ("names", "require_all")
+    names: Seq[str]
+    require_all: bool
 
+    def __repr__(self) -> str:
+        els = ", ".join(f"{nm!r}" for nm in self.names)
+        return f"ncs.by_name({els}, require_all={self.require_all})"
+
+    @staticmethod
+    def _iter_validate(names: tuple[OneOrIterable[str], ...], /) -> Iterator[str]:
+        for name in flatten_hash_safe(names):
+            if not isinstance(name, str):
+                msg = f"invalid name: {name!r}"
+                raise TypeError(msg)
+            yield name
+
+    @staticmethod
+    def 
from_names(*names: OneOrIterable[str], require_all: bool = True) -> ByName: + return ByName(names=tuple(ByName._iter_validate(names)), require_all=require_all) + + @staticmethod + def from_name(name: str, /, *, require_all: bool = True) -> ByName: + return ByName(names=(name,), require_all=require_all) + + def iter_expand( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + if not self.require_all: + keys = schema.keys() + yield from (name for name in self.names if name in keys) + else: + if not set(schema).issuperset(self.names): + raise column_not_found_error(self.names, schema) + yield from self.names + + +class Matches(Selector): + __slots__ = ("pattern",) + pattern: re.Pattern[str] + + @staticmethod + def from_string(pattern: str, /) -> Matches: + return Matches(pattern=re.compile(pattern)) -class Boolean(Selector): def __repr__(self) -> str: - return "ncs.boolean()" + return f"ncs.matches({self.pattern.pattern!r})" + + def iter_expand( + self, schema: FrozenSchema, ignored_columns: Ignored + ) -> Iterator[str]: + search = self.pattern.search + if ignored_columns: + for name in schema: + if name not in ignored_columns and search(name): + yield name + else: + yield from (name for name in schema if search(name)) + + +class Array(DTypeSelector, dtype=_dtypes.Array): + __slots__ = ("inner", "size") + inner: SelectorIR | None + size: int | None + + def __repr__(self) -> str: + inner = "" if not self.inner else repr(self.inner) + size = self.size or "*" + return f"ncs.array({inner}, size={size})" + + def _matches(self, dtype: IntoDType) -> bool: + return ( + isinstance(dtype, _dtypes.Array) + and _inner_selector_matches(self, dtype) + and (self.size is None or dtype.size == self.size) + ) + - def matches_column(self, name: str, dtype: DType) -> bool: - return isinstance(dtype, dtypes.Boolean) +class Boolean(DTypeSelector, dtype=_dtypes.Boolean): ... -class Categorical(Selector): +class ByDType(DTypeSelector, dtype=DType): + __slots__ = ("dtypes",) + dtypes: frozenset[DType | type[DType]] + def __repr__(self) -> str: - return "ncs.categorical()" + if not self.dtypes: + return "ncs.empty()" + return f"ncs.by_dtype([{', '.join(sorted(map(repr, self.dtypes)))}])" - def matches_column(self, name: str, dtype: DType) -> bool: - return isinstance(dtype, dtypes.Categorical) + def _matches(self, dtype: DType | type[DType]) -> bool: + return dtype in self.dtypes + @staticmethod + def empty() -> ByDType: + return ByDType(dtypes=frozenset()) -class Datetime(Selector): - """Should swallow the [`utils` functions]. - Just re-wrapping them for now, since `CompliantSelectorNamespace` is still using them. +class Categorical(DTypeSelector, dtype=_dtypes.Categorical): ... 
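A simplified model (an assumption, not the real class) of the `ByName.iter_expand` contract above: names come back in the *given* order rather than schema order, and unknown names either raise or are skipped depending on `require_all`:

    def by_name_expand(
        schema: dict[str, object], names: list[str], *, require_all: bool
    ) -> list[str]:
        if require_all:
            if missing := set(names) - set(schema):
                raise KeyError(f"columns not found: {missing}")
            return list(names)
        return [nm for nm in names if nm in schema]

    # by_name_expand({"a": 0, "b": 0}, ["b", "a"], require_all=True) -> ["b", "a"]
    # by_name_expand({"a": 0}, ["a", "z"], require_all=False)        -> ["a"]
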
- [`utils` functions]: https://github.com/narwhals-dev/narwhals/blob/6d524ba04fca6fe2d6d25bdd69f75fabf1d79039/narwhals/utils.py#L1565-L1596 - """ +class Datetime(DTypeSelector, dtype=_dtypes.Datetime): __slots__ = ("time_units", "time_zones") time_units: frozenset[TimeUnit] time_zones: frozenset[str | None] @@ -102,12 +287,14 @@ def from_time_unit_and_time_zone( return Datetime(time_units=frozenset(units), time_zones=frozenset(zones)) def __repr__(self) -> str: - return f"ncs.datetime(time_unit={list(self.time_units)}, time_zone={list(self.time_zones)})" + time_unit = "*" if self.time_units == _ALL_TIME_UNITS else list(self.time_units) + time_zone = "*" if self.time_zones == {"*", None} else list(self.time_zones) + return f"ncs.datetime(time_unit={time_unit}, time_zone={time_zone})" - def matches_column(self, name: str, dtype: DType) -> bool: + def _matches(self, dtype: IntoDType) -> bool: units, zones = self.time_units, self.time_zones return ( - isinstance(dtype, dtypes.Datetime) + isinstance_or_issubclass(dtype, _dtypes.Datetime) and (dtype.time_unit in units) and ( dtype.time_zone in zones or ("*" in zones and dtype.time_zone is not None) @@ -115,81 +302,72 @@ def matches_column(self, name: str, dtype: DType) -> bool: ) -class Matches(Selector): - __slots__ = ("pattern",) - pattern: re.Pattern[str] - - @staticmethod - def from_string(pattern: str, /) -> Matches: - return Matches(pattern=re.compile(pattern)) +class Duration(DTypeSelector, dtype=_dtypes.Duration): + __slots__ = ("time_units",) + time_units: frozenset[TimeUnit] @staticmethod - def from_names(*names: OneOrIterable[str]) -> Matches: - """Implements `cs.by_name` to support `__r__` with column selections.""" - it = flatten_hash_safe(names) - return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$") + def from_time_unit(time_unit: OneOrIterable[TimeUnit] | None, /) -> Duration: + if time_unit is None: + units = _ALL_TIME_UNITS + elif not isinstance(time_unit, str): + units = frozenset(time_unit) + else: + units = frozenset((time_unit,)) + return Duration(time_units=units) def __repr__(self) -> str: - return f"ncs.matches(pattern={self.pattern.pattern!r})" - - def matches_column(self, name: str, dtype: DType) -> bool: - return bool(self.pattern.search(name)) - + time_unit = "*" if self.time_units == _ALL_TIME_UNITS else list(self.time_units) + return f"ncs.duration(time_unit={time_unit})" -class Numeric(Selector): - def __repr__(self) -> str: - return "ncs.numeric()" + def _matches(self, dtype: IntoDType) -> bool: + return isinstance_or_issubclass(dtype, _dtypes.Duration) and ( + dtype.time_unit in self.time_units + ) - def matches_column(self, name: str, dtype: DType) -> bool: - return dtype.is_numeric() +class Enum(DTypeSelector, dtype=_dtypes.Enum): ... -class String(Selector): - def __repr__(self) -> str: - return "ncs.string()" - def matches_column(self, name: str, dtype: DType) -> bool: - return isinstance(dtype, dtypes.String) +class Float(DTypeSelector, dtype=FloatType): ... -def all() -> expr.Selector: - return All().to_selector().to_narwhals() +class Integer(DTypeSelector, dtype=IntegerType): ... 
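A minimal model of the `Datetime._matches` logic above (assumed, with `"*"` standing for "any non-None time zone"):

    def datetime_matches(
        unit: str, zone: str | None, units: frozenset, zones: frozenset
    ) -> bool:
        # Match when the unit is selected and the zone is either selected
        # literally or covered by the "*" wildcard.
        return unit in units and (zone in zones or ("*" in zones and zone is not None))

    # datetime_matches("us", "UTC", frozenset({"us"}), frozenset({"*", None})) -> True
    # datetime_matches("us", None, frozenset({"us"}), frozenset({"*", None}))  -> True
    # datetime_matches("ns", None, frozenset({"us"}), frozenset({"*", None}))  -> False
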
-def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector: - return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() +class List(DTypeSelector, dtype=_dtypes.List): + __slots__ = ("inner",) + inner: SelectorIR | None + def __repr__(self) -> str: + inner = "" if not self.inner else repr(self.inner) + return f"ncs.list({inner})" -def by_name(*names: OneOrIterable[str]) -> expr.Selector: - return Matches.from_names(*names).to_selector().to_narwhals() + def _matches(self, dtype: IntoDType) -> bool: + return isinstance(dtype, _dtypes.List) and _inner_selector_matches(self, dtype) -def boolean() -> expr.Selector: - return Boolean().to_selector().to_narwhals() +class Numeric(DTypeSelector, dtype=NumericType): ... -def categorical() -> expr.Selector: - return Categorical().to_selector().to_narwhals() +class String(DTypeSelector, dtype=_dtypes.String): ... -def datetime( - time_unit: OneOrIterable[TimeUnit] | None = None, - time_zone: OneOrIterable[str | timezone | None] = ("*", None), -) -> expr.Selector: - return ( - Datetime.from_time_unit_and_time_zone(time_unit, time_zone) - .to_selector() - .to_narwhals() - ) +class Struct(DTypeSelector, dtype=_dtypes.Struct): ... -def matches(pattern: str) -> expr.Selector: - return Matches.from_string(pattern).to_selector().to_narwhals() +class Temporal(DTypeSelector, dtype=TemporalType): ... -def numeric() -> expr.Selector: - return Numeric().to_selector().to_narwhals() +@functools.lru_cache(maxsize=128) +def _selector_matches(selector: DTypeSelector, dtype: IntoDType, /) -> bool: + # `DTypeSelector.matches` (uncached) + # -> `_selector_matches` (cached) + # -> `DTypeSelector._matches` (impl) + return selector._matches(dtype) -def string() -> expr.Selector: - return String().to_selector().to_narwhals() +def _inner_selector_matches( + selector: Array | List, dtype: nw_dtypes.Array | nw_dtypes.List +) -> bool: + return selector.inner is None or selector.inner.matches(dtype.inner) diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 6e60a7b530..5478a7154c 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -84,30 +84,32 @@ class IRStringNamespace(IRNamespace): def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Replace: + ) -> Replace: # pragma: no cover return Replace(pattern=pattern, value=value, literal=literal, n=n) def replace_all( self, pattern: str, value: str, *, literal: bool = False - ) -> ReplaceAll: + ) -> ReplaceAll: # pragma: no cover return ReplaceAll(pattern=pattern, value=value, literal=literal) - def strip_chars(self, characters: str | None = None) -> StripChars: + def strip_chars( + self, characters: str | None = None + ) -> StripChars: # pragma: no cover return StripChars(characters=characters) def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) - def slice(self, offset: int, length: int | None = None) -> Slice: + def slice(self, offset: int, length: int | None = None) -> Slice: # pragma: no cover return Slice(offset=offset, length=length) - def head(self, n: int = 5) -> Slice: + def head(self, n: int = 5) -> Slice: # pragma: no cover return self.slice(0, n) - def tail(self, n: int = 5) -> Slice: + def tail(self, n: int = 5) -> Slice: # pragma: no cover return self.slice(-n) - def to_datetime(self, format: str | None = None) -> ToDatetime: + def to_datetime(self, format: str | None = None) -> ToDatetime: # pragma: no 
cover return ToDatetime(format=format) @@ -121,41 +123,43 @@ def len_chars(self) -> Expr: def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Expr: + ) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) - def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: + def replace_all( + self, pattern: str, value: str, *, literal: bool = False + ) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) - def strip_chars(self, characters: str | None = None) -> Expr: + def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.strip_chars(characters)) - def starts_with(self, prefix: str) -> Expr: + def starts_with(self, prefix: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.starts_with(prefix=prefix)) - def ends_with(self, suffix: str) -> Expr: + def ends_with(self, suffix: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) - def slice(self, offset: int, length: int | None = None) -> Expr: + def slice(self, offset: int, length: int | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.slice(offset, length)) - def head(self, n: int = 5) -> Expr: + def head(self, n: int = 5) -> Expr: # pragma: no cover return self._with_unary(self._ir.head(n)) - def tail(self, n: int = 5) -> Expr: + def tail(self, n: int = 5) -> Expr: # pragma: no cover return self._with_unary(self._ir.tail(n)) - def split(self, by: str) -> Expr: + def split(self, by: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.split(by=by)) - def to_datetime(self, format: str | None = None) -> Expr: + def to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_datetime(format)) - def to_lowercase(self) -> Expr: + def to_lowercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_lowercase()) - def to_uppercase(self) -> Expr: + def to_uppercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_uppercase()) diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 11a87599ab..35a622ebd2 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -64,7 +64,7 @@ class Timestamp(TemporalFunction): def from_time_unit(time_unit: TimeUnit = "us", /) -> Timestamp: if not _is_polars_time_unit(time_unit): msg = f"invalid `time_unit` \n\nExpected one of ['ns', 'us', 'ms'], got {time_unit!r}." 
- raise ValueError(msg) + raise TypeError(msg) return Timestamp(time_unit=time_unit) def __repr__(self) -> str: @@ -115,64 +115,64 @@ class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): def _ir_namespace(self) -> type[IRDateTimeNamespace]: return IRDateTimeNamespace - def date(self) -> Expr: + def date(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.date()) - def year(self) -> Expr: + def year(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.year()) - def month(self) -> Expr: + def month(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.month()) - def day(self) -> Expr: + def day(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.day()) - def hour(self) -> Expr: + def hour(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.hour()) - def minute(self) -> Expr: + def minute(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.minute()) - def second(self) -> Expr: + def second(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.second()) - def millisecond(self) -> Expr: + def millisecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.millisecond()) - def microsecond(self) -> Expr: + def microsecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.microsecond()) - def nanosecond(self) -> Expr: + def nanosecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.nanosecond()) - def ordinal_day(self) -> Expr: + def ordinal_day(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.ordinal_day()) - def weekday(self) -> Expr: + def weekday(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.weekday()) def total_minutes(self) -> Expr: return self._with_unary(self._ir.total_minutes()) - def total_seconds(self) -> Expr: + def total_seconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_seconds()) - def total_milliseconds(self) -> Expr: + def total_milliseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_milliseconds()) - def total_microseconds(self) -> Expr: + def total_microseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_microseconds()) - def total_nanoseconds(self) -> Expr: + def total_nanoseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_nanoseconds()) - def to_string(self, format: str) -> Expr: + def to_string(self, format: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_string(format=format)) - def replace_time_zone(self, time_zone: str | None) -> Expr: + def replace_time_zone(self, time_zone: str | None) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace_time_zone(time_zone=time_zone)) - def convert_time_zone(self, time_zone: str) -> Expr: + def convert_time_zone(self, time_zone: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.convert_time_zone(time_zone=time_zone)) def timestamp(self, time_unit: TimeUnit = "us") -> Expr: diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index c07fe92c29..8e930d1044 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -2,8 +2,9 @@ import builtins import typing as t +from typing import TYPE_CHECKING -from narwhals._plan import _guards, _parse, common, expressions as ir +from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs from narwhals._plan.expressions import functions as F from 
narwhals._plan.expressions.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.expressions.ranges import IntRange @@ -11,7 +12,7 @@ from narwhals._plan.when_then import When from narwhals._utils import Version, flatten -if t.TYPE_CHECKING: +if TYPE_CHECKING: from narwhals._plan.expr import Expr from narwhals._plan.series import Series from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT @@ -21,14 +22,15 @@ def col(*names: str | t.Iterable[str]) -> Expr: flat = tuple(flatten(names)) - node = ir.col(flat[0]) if builtins.len(flat) == 1 else ir.cols(*flat) - return node.to_narwhals() + return ( + ir.col(flat[0]).to_narwhals() + if builtins.len(flat) == 1 + else cs.by_name(*flat).as_expr() + ) def nth(*indices: int | t.Sequence[int]) -> Expr: - flat = tuple(flatten(indices)) - node = ir.nth(flat[0]) if builtins.len(flat) == 1 else ir.index_columns(*flat) - return node.to_narwhals() + return cs.by_index(*indices).as_expr() def lit( @@ -51,11 +53,11 @@ def len() -> Expr: def all() -> Expr: - return ir.All().to_narwhals() + return cs.all().as_expr() def exclude(*names: str | t.Iterable[str]) -> Expr: - return all().exclude(*names) + return cs.all().exclude(*names).as_expr() def max(*columns: str) -> Expr: diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index bb7a4315b3..41487b503b 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -7,34 +7,38 @@ from __future__ import annotations from functools import lru_cache -from itertools import chain from typing import TYPE_CHECKING, Literal, overload from narwhals._plan import expressions as ir from narwhals._plan._guards import is_literal +from narwhals._plan.expressions import selectors as cs from narwhals._plan.expressions.literal import is_literal_scalar from narwhals._plan.expressions.namespace import IRNamespace -from narwhals.exceptions import ComputeError +from narwhals.exceptions import ComputeError, InvalidOperationError from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator + + from narwhals._plan import Selector class MetaNamespace(IRNamespace): """Methods to modify and traverse existing expressions.""" def has_multiple_outputs(self) -> bool: - return any(_has_multiple_outputs(e) for e in self._ir.iter_left()) + return any(isinstance(e, ir.SelectorIR) for e in self._ir.iter_left()) def is_column(self) -> bool: return isinstance(self._ir, ir.Column) def is_column_selection(self, *, allow_aliasing: bool = False) -> bool: - return all( - _is_column_selection(e, allow_aliasing=allow_aliasing) - for e in self._ir.iter_left() - ) + nodes = self._ir.iter_left() + selection = ir.Column, ir.SelectorIR + if not allow_aliasing: + return all(isinstance(e, selection) for e in nodes) + targets = *selection, ir.Alias, ir.KeepName, ir.RenameAlias + return all(isinstance(e, targets) for e in nodes) def is_literal(self, *, allow_aliasing: bool = False) -> bool: return all( @@ -73,42 +77,25 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: def root_names(self) -> list[str]: """Get the root column names.""" - return list(_expr_to_leaf_column_names_iter(self._ir)) - - -def _expr_to_leaf_column_names_iter(expr: ir.ExprIR, /) -> Iterator[str]: - for e in _expr_to_leaf_column_exprs_iter(expr): - result = _expr_to_leaf_column_name(e) - if isinstance(result, str): - yield result - - -def _expr_to_leaf_column_exprs_iter(expr: ir.ExprIR, /) -> Iterator[ir.ExprIR]: - for outer in 
expr.iter_root_names():
-        if isinstance(outer, (ir.Column, ir.All)):
-            yield outer
-
-
-def _expr_to_leaf_column_name(expr: ir.ExprIR, /) -> str | ComputeError:
-    leaves = list(_expr_to_leaf_column_exprs_iter(expr))
-    if not len(leaves) <= 1:
-        msg = "found more than one root column name"
-        return ComputeError(msg)
-    if not leaves:
-        msg = "no root column name found"
-        return ComputeError(msg)
-    leaf = leaves[0]
-    if isinstance(leaf, ir.Column):
-        return leaf.name
-    if isinstance(leaf, ir.All):
-        msg = "wildcard has no root column name"
-        return ComputeError(msg)
-    msg = f"Expected unreachable, got {type(leaf).__name__!r}\n\n{leaf}"
-    return ComputeError(msg)
+    return list(iter_root_names(self._ir))
+
+    def as_selector(self) -> Selector:
+        """Try to turn this expression into a selector.
+
+        Raises if the underlying expression is not a column or selector.
+        """
+        return self._ir.to_selector_ir().to_narwhals()
+
+
+def iter_root_names(expr: ir.ExprIR, /) -> Iterator[str]:
+    yield from (e.name for e in expr.iter_root_names() if isinstance(e, ir.Column))
 
-def root_names_unique(exprs: Iterable[ir.ExprIR], /) -> set[str]:
-    return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in exprs))
+def root_name_first(expr: ir.ExprIR, /) -> str:
+    if name := next(iter_root_names(expr), None):
+        return name
+    msg = f"`name.keep` expected at least one column name, got `{expr!r}`"
+    raise InvalidOperationError(msg)
 
 
 @lru_cache(maxsize=32)
@@ -116,44 +103,22 @@ def _expr_output_name(expr: ir.ExprIR, /) -> str | ComputeError:
     for e in expr.iter_output_name():
         if isinstance(e, (ir.Column, ir.Alias, ir.Literal, ir.Len)):
             return e.name
-        if isinstance(e, (ir.All, ir.KeepName, ir.RenameAlias)):
+        if isinstance(e, ir.RenameAlias):
+            parent = _expr_output_name(e.expr)
+            return e.function(parent) if isinstance(parent, str) else parent
+        if isinstance(e, ir.KeepName):
             msg = "cannot determine output column without a context for this expression"
             return ComputeError(msg)
-        if isinstance(e, (ir.Columns, ir.IndexColumns, ir.Nth)):
-            msg = "this expression may produce multiple output names"
-            return ComputeError(msg)
-        continue
+        if isinstance(e, ir.RootSelector) and (
+            isinstance(e.selector, cs.ByName) and len(e.selector.names) == 1
+        ):
+            return e.selector.names[0]
     msg = (
         f"unable to find root column name for expr '{expr!r}' when calling 'output_name'"
     )
     return ComputeError(msg)
 
 
-def get_single_leaf_name(expr: ir.ExprIR, /) -> str | ComputeError:
-    """Find the name at the start of an expression.
-
-    Normal iteration would just return the first root column it found.
-
-    Based on [`polars_plan::utils::get_single_leaf`]
-
-    [`polars_plan::utils::get_single_leaf`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/utils.rs#L151-L168
-    """
-    for e in expr.iter_right():
-        if isinstance(e, (ir.WindowExpr, ir.SortBy, ir.Filter)):
-            return get_single_leaf_name(e.expr)
-        if isinstance(e, ir.BinaryExpr):
-            return get_single_leaf_name(e.left)
-        # NOTE: `polars` doesn't include `Literal` here
-        if isinstance(e, (ir.Column, ir.Len)):
-            return e.name
-    msg = f"unable to find a single leaf column in expr '{expr!r}'"
-    return ComputeError(msg)
-
-
-def _has_multiple_outputs(expr: ir.ExprIR, /) -> bool:
-    return isinstance(expr, (ir.Columns, ir.IndexColumns, ir.SelectorIR, ir.All))
-
-
 def has_expr_ir(expr: ir.ExprIR, *matches: type[ir.ExprIR]) -> bool:
     """Return True if any node in the tree is in type `matches`. 
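# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch) for the `meta` helpers
# above; `.meta` access on the IR mirrors the tests touched in this PR.
from narwhals import _plan as nwp

e_ir = (nwp.col("a") + nwp.col("b")).alias("total")._ir
assert e_ir.meta.output_name() == "total"            # peeks at the outermost Alias
assert sorted(e_ir.meta.root_names()) == ["a", "b"]  # leaf `Column` names
# Multi-name selections are now selector-backed, so they report multiple outputs:
assert nwp.col("x", "y")._ir.meta.has_multiple_outputs()
# ---------------------------------------------------------------------------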
@@ -174,9 +139,3 @@ def _is_literal(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: and isinstance(expr.expr.dtype, Version.MAIN.dtypes.Datetime) ) ) - - -def _is_column_selection(expr: ir.ExprIR, /, *, allow_aliasing: bool) -> bool: - return isinstance(expr, (ir.Column, ir._ColumnSelection, ir.SelectorIR)) or ( - allow_aliasing and isinstance(expr, (ir.Alias, ir.KeepName, ir.RenameAlias)) - ) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 739654e4bf..03550afb7c 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -59,7 +59,7 @@ def returns_scalar(self) -> bool: return FunctionFlags.RETURNS_SCALAR in self def is_length_preserving(self) -> bool: - return FunctionFlags.LENGTH_PRESERVING in self + return FunctionFlags.LENGTH_PRESERVING in self # pragma: no cover def is_row_separable(self) -> bool: return FunctionFlags.ROW_SEPARABLE in self @@ -92,7 +92,7 @@ def returns_scalar(self) -> bool: return self.flags.returns_scalar() def is_length_preserving(self) -> bool: - return self.flags.is_length_preserving() + return self.flags.is_length_preserving() # pragma: no cover def is_row_separable(self) -> bool: return self.flags.is_row_separable() @@ -102,8 +102,8 @@ def is_input_wildcard_expansion(self) -> bool: def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: - msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." - raise TypeError(msg) + msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." # pragma: no cover + raise TypeError(msg) # pragma: no cover obj = FunctionOptions.__new__(FunctionOptions) object.__setattr__(obj, "flags", self.flags | flags) return obj @@ -155,10 +155,6 @@ def __repr__(self) -> str: args = f"descending={self.descending!r}, nulls_last={self.nulls_last!r}" return f"{type(self).__name__}({args})" - @staticmethod - def default() -> SortOptions: - return SortOptions(descending=False, nulls_last=False) - def to_arrow(self) -> pc.ArraySortOptions: import pyarrow.compute as pc @@ -201,7 +197,7 @@ def _to_arrow_args( ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): - msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" + msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" # pragma: no cover raise NotImplementedError(msg) if len(self.descending) == 1: descending: Iterable[bool] = repeat(self.descending[0], len(by)) @@ -219,7 +215,9 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: sort_keys, placement = self._to_arrow_args(by) return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) - def to_arrow_acero(self, by: Sequence[str]) -> pyarrow.acero.Declaration: + def to_arrow_acero( + self, by: Sequence[str] + ) -> pyarrow.acero.Declaration: # pragma: no cover from narwhals._plan.arrow import acero sort_keys, placement = self._to_arrow_args(by) @@ -291,7 +289,7 @@ def __repr__(self) -> str: return self.__str__() @classmethod - def default(cls) -> Self: + def default(cls) -> Self: # pragma: no cover[abstract] return cls(is_namespaced=False, override_name="") @classmethod diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 10c6665d08..2c15cf7804 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -4,7 +4,7 @@ 
from functools import lru_cache from itertools import chain from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, final, overload from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import Immutable @@ -12,7 +12,7 @@ from narwhals.dtypes import Unknown if TYPE_CHECKING: - from collections.abc import ItemsView, Iterator, KeysView, ValuesView + from collections.abc import ItemsView, Iterable, Iterator, KeysView, ValuesView from typing_extensions import Never, TypeAlias, TypeIs @@ -22,7 +22,7 @@ IntoFrozenSchema: TypeAlias = ( - "IntoSchema | Iterator[tuple[str, DType]] | FrozenSchema | HasSchema" + "IntoSchema | Iterable[tuple[str, DType]] | FrozenSchema | HasSchema" ) """A schema to freeze, or an already frozen one. @@ -35,6 +35,7 @@ _T2 = TypeVar("_T2") +@final class FrozenSchema(Immutable): """Use `freeze_schema(...)` constructor to trigger caching!""" @@ -42,7 +43,7 @@ class FrozenSchema(Immutable): _mapping: MappingProxyType[str, DType] def __init_subclass__(cls, *_: Never, **__: Never) -> Never: - msg = f"Cannot subclass {cls.__name__!r}" + msg = f"Cannot subclass {FrozenSchema.__name__!r}" raise TypeError(msg) def merge(self, other: FrozenSchema, /) -> FrozenSchema: @@ -69,7 +70,7 @@ def select(self, exprs: Seq[NamedIR]) -> FrozenSchema: def select_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: return exprs - def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: + def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: # pragma: no cover # similar to `merge`, but preserving known `DType`s names = (e.name for e in exprs) default = Unknown() diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py new file mode 100644 index 0000000000..65c7197895 --- /dev/null +++ b/narwhals/_plan/selectors.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import operator +from collections import deque +from functools import reduce +from typing import TYPE_CHECKING, Any, ClassVar, overload + +from narwhals._plan import expressions as ir +from narwhals._plan._guards import is_column, is_re_pattern +from narwhals._plan.common import flatten_hash_safe +from narwhals._plan.expr import Expr, ExprV1 +from narwhals._plan.expressions import operators as ops, selectors as s_ir +from narwhals._utils import Version +from narwhals.dtypes import DType + +if TYPE_CHECKING: + import re + from collections.abc import Callable, Mapping + from datetime import timezone + + from typing_extensions import Never, Self + + from narwhals._plan.typing import OneOrIterable + from narwhals.typing import TimeUnit + +__all__ = [ + "Selector", + "all", + "array", + "boolean", + "by_dtype", + "by_index", + "by_name", + "categorical", + "datetime", + "duration", + "empty", + "enum", + "first", + "float", + "integer", + "last", + "list", + "matches", + "numeric", + "string", + "struct", + "temporal", +] + +_dtypes = Version.MAIN.dtypes +_dtypes_v1 = Version.V1.dtypes + + +class Selector(Expr): + _ir: ir.SelectorIR + + def __repr__(self) -> str: + return f"nw._plan.Selector({self.version.name.lower()}):\n{self._ir!r}" + + @classmethod + def _from_ir(cls, selector_ir: ir.SelectorIR, /) -> Self: # type: ignore[override] + obj = cls.__new__(cls) + obj._ir = selector_ir + return obj + + def as_expr(self) -> Expr: + tp = Expr if self.version is Version.MAIN else ExprV1 + return tp._from_ir(self._ir) + + def exclude(self, *names: OneOrIterable[str]) -> Selector: + return self - 
by_name(*names) # pyright: ignore[reportReturnType] + + def __invert__(self) -> Self: + return self._from_ir(ir.InvertSelector(selector=self._ir)) + + def __add__(self, other: Any) -> Expr: # type: ignore[override] + if isinstance(other, type(self)): + return self.as_expr().__add__(other.as_expr()) + return self.as_expr().__add__(other) + + def __radd__(self, other: Any) -> Expr: # type: ignore[override] + return self.as_expr().__radd__(other) + + @overload # type: ignore[override] + def __and__(self, other: Self) -> Self: ... + @overload + def __and__(self, other: Any) -> Expr: ... + def __and__(self, other: Any) -> Self | Expr: + if is_column(other): # @polars>=2.0: remove + other = by_name(other.meta.output_name()) + if isinstance(other, type(self)): + op = ops.And() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self.as_expr().__and__(other) + + def __rand__(self, other: Any) -> Expr: # type: ignore[override] + return self.as_expr().__rand__(other) + + @overload # type: ignore[override] + def __or__(self, other: Self) -> Self: ... + @overload + def __or__(self, other: Any) -> Expr: ... + def __or__(self, other: Any) -> Self | Expr: + if is_column(other): # @polars>=2.0: remove + other = by_name(other.meta.output_name()) + if isinstance(other, type(self)): + op = ops.Or() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self.as_expr().__or__(other) + + def __ror__(self, other: Any) -> Expr: # type: ignore[override] + if is_column(other): + other = by_name(other.meta.output_name()) + return self.as_expr().__ror__(other) + + @overload # type: ignore[override] + def __sub__(self, other: Self) -> Self: ... + @overload + def __sub__(self, other: Any) -> Expr: ... + def __sub__(self, other: Any) -> Self | Expr: + if isinstance(other, type(self)): + op = ops.Sub() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self.as_expr().__sub__(other) + + def __rsub__(self, other: Any) -> Never: + msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" + raise TypeError(msg) + + @overload # type: ignore[override] + def __xor__(self, other: Self) -> Self: ... + @overload + def __xor__(self, other: Any) -> Expr: ... 
+ def __xor__(self, other: Any) -> Self | Expr: + if is_column(other): # @polars>=2.0: remove + other = by_name(other.meta.output_name()) + if isinstance(other, type(self)): + op = ops.ExclusiveOr() + return self._from_ir(op.to_binary_selector(self._ir, other._ir)) + return self.as_expr().__xor__(other) + + def __rxor__(self, other: Any) -> Expr: # type: ignore[override] + if is_column(other): # @polars>=2.0: remove + other = by_name(other.meta.output_name()) + return self.as_expr().__rxor__(other) + + +class SelectorV1(Selector): + _version: ClassVar[Version] = Version.V1 + + +def all() -> Selector: + return s_ir.All().to_selector_ir().to_narwhals() + + +def array(inner: Selector | None = None, *, size: int | None = None) -> Selector: + s = inner._ir.to_dtype_selector() if inner is not None else None + return s_ir.Array(inner=s, size=size).to_selector_ir().to_narwhals() + + +def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> Selector: + selectors: deque[Selector] = deque() + dtypes_: deque[DType | type[DType]] = deque() + for tp in flatten_hash_safe(dtypes): + if isinstance(tp, type) and issubclass(tp, DType): + if constructor := _HASH_SENSITIVE_TO_SELECTOR.get(tp): + selectors.append(constructor()) + else: + dtypes_.append(tp) + elif isinstance(tp, DType): + dtypes_.append(tp) + else: + msg = f"invalid dtype: {tp!r}" + raise TypeError(msg) + if dtypes_: + dtype_selector = ( + s_ir.ByDType(dtypes=frozenset(dtypes_)).to_selector_ir().to_narwhals() + ) + selectors.appendleft(dtype_selector) + it = iter(selectors) + if first := next(it, None): + return reduce(operator.or_, it, first) + return empty() + + +def by_index(*indices: OneOrIterable[int], require_all: bool = True) -> Selector: + if len(indices) == 1 and isinstance(indices[0], int): + sel = s_ir.ByIndex.from_index(indices[0], require_all=require_all) + else: + sel = s_ir.ByIndex.from_indices(*indices, require_all=require_all) + return sel.to_selector_ir().to_narwhals() + + +def by_name(*names: OneOrIterable[str], require_all: bool = True) -> Selector: + if len(names) == 1 and isinstance(names[0], str): + sel = s_ir.ByName.from_name(names[0], require_all=require_all) + else: + sel = s_ir.ByName.from_names(*names, require_all=require_all) + return sel.to_selector_ir().to_narwhals() + + +def boolean() -> Selector: + return s_ir.Boolean().to_selector_ir().to_narwhals() + + +def categorical() -> Selector: + return s_ir.Categorical().to_selector_ir().to_narwhals() + + +def datetime( + time_unit: OneOrIterable[TimeUnit] | None = None, + time_zone: OneOrIterable[str | timezone | None] = ("*", None), +) -> Selector: + return ( + s_ir.Datetime.from_time_unit_and_time_zone(time_unit, time_zone) + .to_selector_ir() + .to_narwhals() + ) + + +def duration(time_unit: OneOrIterable[TimeUnit] | None = None) -> Selector: + return s_ir.Duration.from_time_unit(time_unit).to_selector_ir().to_narwhals() + + +def empty() -> Selector: + return s_ir.ByDType.empty().to_selector_ir().to_narwhals() + + +def enum() -> Selector: + return s_ir.Enum().to_selector_ir().to_narwhals() + + +def first() -> Selector: + return s_ir.ByIndex.from_index(0).to_selector_ir().to_narwhals() + + +def float() -> Selector: + return s_ir.Float().to_selector_ir().to_narwhals() + + +def integer() -> Selector: + return s_ir.Integer().to_selector_ir().to_narwhals() + + +def last() -> Selector: + return s_ir.ByIndex.from_index(-1).to_selector_ir().to_narwhals() + + +def list(inner: Selector | None = None) -> Selector: + s = inner._ir.to_dtype_selector() if inner is not None 
else None + return s_ir.List(inner=s).to_selector_ir().to_narwhals() + + +def matches(pattern: str | re.Pattern[str]) -> Selector: + tp = s_ir.Matches + s = tp(pattern=pattern) if is_re_pattern(pattern) else tp.from_string(pattern) + return s.to_selector_ir().to_narwhals() + + +def numeric() -> Selector: + return s_ir.Numeric().to_selector_ir().to_narwhals() + + +def string() -> Selector: + return s_ir.String().to_selector_ir().to_narwhals() + + +def struct() -> Selector: + return s_ir.Struct().to_selector_ir().to_narwhals() + + +def temporal() -> Selector: + return s_ir.Temporal().to_selector_ir().to_narwhals() + + +_HASH_SENSITIVE_TO_SELECTOR: Mapping[type[DType], Callable[[], Selector]] = { + _dtypes.Datetime: datetime, + _dtypes_v1.Datetime: datetime, + _dtypes.Duration: duration, + _dtypes_v1.Duration: duration, + _dtypes.Enum: enum, + _dtypes_v1.Enum: enum, + _dtypes.Array: array, + _dtypes.List: list, + _dtypes.Struct: struct, +} diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index d3a17f29a4..255a8442e0 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator + from typing_extensions import Self + from narwhals._plan.compliant.series import CompliantSeries from narwhals._typing import EagerAllowed, IntoBackend from narwhals.dtypes import DType @@ -54,8 +56,9 @@ def from_iterable( ) ) raise NotImplementedError(implementation) - msg = f"{implementation} support in Narwhals is lazy-only" - raise ValueError(msg) + else: # pragma: no cover # noqa: RET506 + msg = f"{implementation} support in Narwhals is lazy-only" + raise ValueError(msg) @classmethod def from_native( @@ -74,9 +77,12 @@ def to_native(self) -> NativeSeriesT_co: def to_list(self) -> list[Any]: return self._compliant.to_list() - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[Any]: # pragma: no cover yield from self.to_native() + def alias(self, name: str) -> Self: + return type(self)(self._compliant.alias(name)) + class SeriesV1(Series[NativeSeriesT_co]): _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 1449d245ac..060083ba4b 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -1,10 +1,12 @@ from __future__ import annotations import typing as t +from collections.abc import Container +from typing import TYPE_CHECKING from narwhals._typing_compat import TypeVar -if t.TYPE_CHECKING: +if TYPE_CHECKING: from collections.abc import Callable, Iterable from typing_extensions import TypeAlias @@ -14,11 +16,12 @@ from narwhals._plan._expr_ir import ExprIR, NamedIR, SelectorIR from narwhals._plan._function import Function from narwhals._plan.dataframe import DataFrame - from narwhals._plan.expr import Expr, Selector + from narwhals._plan.expr import Expr from narwhals._plan.expressions import operators as ops from narwhals._plan.expressions.functions import RollingWindow from narwhals._plan.expressions.namespace import IRNamespace from narwhals._plan.expressions.ranges import RangeFunction + from narwhals._plan.selectors import Selector from narwhals._plan.series import Series from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -26,6 +29,7 @@ "ColumnNameOrSelector", "DataFrameT", "FunctionT", + "Ignored", "IntoExpr", "IntoExprColumn", "LeftSelectorT", @@ -126,3 +130,10 @@ Order: TypeAlias = t.Literal["ascending", "descending"] NonCrossJoinStrategy: TypeAlias = t.Literal["inner", "left", "full", "semi", "anti"] 
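# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch) for the selector algebra
# defined in `narwhals/_plan/selectors.py` above; column names are made up.
# Selector-vs-selector set ops stay selectors, while anything else falls back
# to `Expr` semantics via `as_expr()`.
from narwhals._plan import selectors as ncs

keep = ncs.numeric() - ncs.by_name("id")     # set difference -> Selector
flags = ncs.boolean() | ncs.matches("^is_")  # union          -> Selector
nontime = ~ncs.temporal()                    # complement     -> Selector
scaled = ncs.numeric().as_expr() * 100       # arithmetic     -> plain Expr
dropped = ncs.all().exclude("id")            # sugar for `- by_name("id")`
# ---------------------------------------------------------------------------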
PartialSeries: TypeAlias = "Callable[[Iterable[t.Any]], Series[NativeSeriesAnyT]]" + + +Ignored: TypeAlias = Container[str] +"""Names of `group_by` columns, which are excluded[^1] when expanding a `Selector`. + +[^1]: `ByName`, `ByIndex` will never be ignored. +""" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 18ae514f0d..ce51e19087 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,19 +2,31 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan._guards import is_expr from narwhals._plan._immutable import Immutable from narwhals._plan._parse import ( - parse_into_expr_ir, + parse_into_expr_ir as _parse_into_expr_ir, parse_predicates_constraints_into_expr_ir, ) from narwhals._plan.expr import Expr +from narwhals.exceptions import MultiOutputExpressionError if TYPE_CHECKING: from narwhals._plan.expressions import ExprIR, TernaryExpr from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq +def _multi_output_error(expr: ExprIR) -> MultiOutputExpressionError: + msg = f"Multi-output expressions are not supported in a `when-then-otherwise` context: `{expr!r}`" + return MultiOutputExpressionError(msg) + + +def parse_into_expr_ir(statement: IntoExpr, /) -> ExprIR: + expr_ir = _parse_into_expr_ir(statement) + if expr_ir.meta.has_multiple_outputs(): + raise _multi_output_error(expr_ir) + return expr_ir + + class When(Immutable): __slots__ = ("condition",) condition: ExprIR @@ -22,10 +34,6 @@ class When(Immutable): def then(self, expr: IntoExpr, /) -> Then: return Then(condition=self.condition, statement=parse_into_expr_ir(expr)) - @staticmethod - def _from_expr(expr: Expr, /) -> When: - return When(condition=expr._ir) - @staticmethod def _from_ir(expr_ir: ExprIR, /) -> When: return When(condition=expr_ir) @@ -58,10 +66,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] return Expr._from_ir(expr_ir) - def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] - if is_expr(value): - return super(Expr, self).__eq__(value) - return super().__eq__(value) + def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] + return Expr.__eq__(self, other) class ChainedWhen(Immutable): @@ -106,10 +112,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] return Expr._from_ir(expr_ir) - def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] - if is_expr(value): - return super(Expr, self).__eq__(value) - return super().__eq__(value) + def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] + return Expr.__eq__(self, other) def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: diff --git a/pyproject.toml b/pyproject.toml index 75cc580102..c45809ca51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -300,8 +300,10 @@ omit = [ 'narwhals/_ibis/typing.py', # Remove after finishing eager sub-protocols 'narwhals/_compliant/namespace.py', - # Doesn't have a full impl yet - 'narwhals/_plan/*' + # NOTE: Gradually adding as things become more stable + 'narwhals/_plan/arrow/*', + 'narwhals/_plan/compliant/*', + 'narwhals/_plan/**/typing.py', ] exclude_also = [ "if sys.version_info() <", @@ -321,7 +323,14 @@ exclude_also = [ 'if ".*" in str\(constructor', 'pytest.skip\(', 'assert_never\(', - 'PANDAS_VERSION < \(' + 'PANDAS_VERSION < \(', + 'def __repr__', + 'def __str__', + # Extends a `covdefaults` pattern 
to account for `EM10{1,2}` + # https://github.com/asottile/covdefaults/blob/a5228df597ffc7933bb2fb5b7bad94119a40a896/covdefaults.py#L90-L92 + # https://docs.astral.sh/ruff/rules/raw-string-in-exception/ + # https://docs.astral.sh/ruff/rules/f-string-in-exception/ + '^\s*msg = .+\n\s*raise NotImplementedError\b' ] [tool.mypy] @@ -357,6 +366,7 @@ module = [ "narwhals._arrow.*", "narwhals._dask.*", "narwhals._spark_like.*", + "narwhals._plan.expressions.expr" ] warn_return_any = false diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index e011e5c547..3d33764223 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -1,10 +1,13 @@ from __future__ import annotations +import re +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import pytest -from narwhals._plan import selectors as ndcs +from narwhals._plan import selectors as ncs +from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError pytest.importorskip("pyarrow") pytest.importorskip("numpy") @@ -14,13 +17,14 @@ import narwhals as nw from narwhals import _plan as nwp from narwhals._utils import Version -from narwhals.exceptions import ComputeError -from tests.plan.utils import assert_equal_data, dataframe +from tests.plan.utils import assert_equal_data, dataframe, first, last if TYPE_CHECKING: from collections.abc import Sequence + from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable from narwhals.typing import PythonLiteral + from tests.conftest import Data @pytest.fixture @@ -45,12 +49,19 @@ def data_small() -> dict[str, Any]: @pytest.fixture -def data_smaller(data_small: dict[str, Any]) -> dict[str, Any]: +def data_small_af(data_small: dict[str, Any]) -> dict[str, Any]: """Use only columns `"a"-"f"`.""" keep = {"a", "b", "c", "d", "e", "f"} return {k: v for k, v in data_small.items() if k in keep} +@pytest.fixture +def data_small_dh(data_small: dict[str, Any]) -> dict[str, Any]: + """Use only columns `"d"-"h"`.""" + keep = {"d", "e", "f", "g", "h"} + return {k: v for k, v in data_small.items() if k in keep} + + @pytest.fixture def data_indexed() -> dict[str, Any]: """Used in https://github.com/narwhals-dev/narwhals/pull/2528.""" @@ -69,12 +80,6 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: return repr(expr) -XFAIL_REWRITE_SPECIAL_ALIASES = pytest.mark.xfail( - reason="https://github.com/narwhals-dev/narwhals/blob/3732e5a6b56411157f13307dfdbd25e397d5b8e6/narwhals/_plan/meta.py#L142-L162\n" - "Matches behavior of `polars`\n" - "pl.select(pl.lit(1).name.suffix('_suffix'))", - raises=ComputeError, -) XFAIL_KLEENE_ALL_NULL = pytest.mark.xfail( reason="`pyarrow` uses `pa.null()`, which also fails in current `narwhals`.\n" "In `polars`, the same op is supported and it uses `pl.Null`.\n\n" @@ -107,11 +112,9 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: (nwp.col("b").cast(nw.Float64()), {"b": [1.0, 2.0, 3.0]}), (nwp.lit(1).cast(nw.Float64).alias("literal_cast"), {"literal_cast": [1.0]}), pytest.param( - nwp.lit(1).cast(nw.Float64()).name.suffix("_cast"), - {"literal_cast": [1.0]}, - marks=XFAIL_REWRITE_SPECIAL_ALIASES, + nwp.lit(1).cast(nw.Float64()).name.suffix("_cast"), {"literal_cast": [1.0]} ), - ([ndcs.string().first(), nwp.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), + ([ncs.string().first(), nwp.col("b")], {"a": ["A", "A", "A"], "b": [1, 2, 3]}), ( nwp.col("c", "d") .sort_by("a", "b", descending=[True, False]) @@ -381,7 +384,7 @@ def _ids_ir(expr: nwp.Expr | Any) -> str: id="map_batches-numpy", ), pytest.param( - 
ndcs.by_name("b", "c", "d") + ncs.by_name("b", "c", "d") .map_batches(lambda s: np.append(s.to_numpy(), [10, 2]), is_elementwise=True) .sort(), {"b": [1, 2, 2, 3, 10], "c": [2, 2, 4, 9, 10], "d": [2, 7, 8, 8, 10]}, @@ -430,7 +433,7 @@ def test_select( }, ), ( - ndcs.numeric().cast(nw.String), + ncs.numeric().cast(nw.String), { "a": ["A", "B", "A"], "b": ["1", "2", "3"], @@ -458,7 +461,7 @@ def test_select( pytest.param( [ nwp.col("a").alias("a?"), - ndcs.by_name("a"), + ncs.by_name("a"), nwp.col("b").cast(nw.Float64).name.suffix("_float"), nwp.col("c").max() + 1, nwp.sum_horizontal(1, "d", nwp.col("b"), nwp.lit(3)), @@ -481,20 +484,12 @@ def test_select( def test_with_columns( expr: nwp.Expr | Sequence[nwp.Expr], expected: dict[str, Any], - data_smaller: dict[str, Any], + data_small_af: dict[str, Any], ) -> None: - result = dataframe(data_smaller).with_columns(expr) + result = dataframe(data_small_af).with_columns(expr) assert_equal_data(result, expected) -def first(*names: str) -> nwp.Expr: - return nwp.col(*names).first() - - -def last(*names: str) -> nwp.Expr: - return nwp.col(*names).last() - - @pytest.mark.parametrize( ("agg", "expected"), [ @@ -535,6 +530,100 @@ def test_row_is_py_literal( assert result == polars_result +@pytest.mark.parametrize( + ("columns", "expected"), + [ + ("a", ["b", "c"]), + (["a"], ["b", "c"]), + (ncs.first(), ["b", "c"]), + ([ncs.first()], ["b", "c"]), + (["a", "b"], ["c"]), + (~ncs.last(), ["c"]), + ([ncs.integer() | ncs.enum()], ["c"]), + ([ncs.first(), "b"], ["c"]), + (ncs.all(), []), + ([], ["a", "b", "c"]), + (ncs.struct(), ["a", "b", "c"]), + ], +) +def test_drop(columns: OneOrIterable[ColumnNameOrSelector], expected: list[str]) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]} + df = dataframe(data) + if isinstance(columns, (str, nwp.Selector, list)): + assert df.drop(columns).collect_schema().names() == expected + else: # pragma: no cover + ... 
+ if not isinstance(columns, str) and isinstance(columns, Iterable): + assert df.drop(*columns).collect_schema().names() == expected + + +def test_drop_strict() -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6]} + df = dataframe(data) + with pytest.raises(ColumnNotFoundError): + df.drop("z") + with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['z']")): + df.drop(ncs.last(), "z") + assert df.drop("z", strict=False).collect_schema().names() == ["a", "b"] + assert df.drop(ncs.last(), "z", strict=False).collect_schema().names() == ["a"] + + +def test_drop_nulls(data_small_dh: Data) -> None: + df = dataframe(data_small_dh) + expected: Data = {"d": [], "e": [], "f": [], "g": [], "h": []} + result = df.drop_nulls() + assert_equal_data(result, expected) + + +def test_drop_nulls_invalid(data_small_dh: Data) -> None: + df = dataframe(data_small_dh) + with pytest.raises(TypeError, match=r"cannot turn.+int.+into a selector"): + df.drop_nulls(123) # type: ignore[arg-type] + with pytest.raises( + InvalidOperationError, match=r"cannot turn.+col\('a'\).first\(\).+into a selector" + ): + df.drop_nulls(nwp.col("a").first()) # type: ignore[arg-type] + + with pytest.raises(ColumnNotFoundError): + df.drop_nulls(["j", "k"]) + + with pytest.raises(ColumnNotFoundError): + df.drop_nulls(ncs.by_name("j", "k")) + + with pytest.raises(ColumnNotFoundError): + df.drop_nulls(ncs.by_index(-999)) + + +DROP_ROW_1: Data = { + "d": [7, 8], + "e": [9, 7], + "f": [False, None], + "g": [None, False], + "h": [None, True], +} +KEEP_ROW_3: Data = {"d": [8], "e": [7], "f": [None], "g": [False], "h": [True]} + + +@pytest.mark.parametrize( + ("subset", "expected"), + [ + ("e", DROP_ROW_1), + (nwp.col("e"), DROP_ROW_1), + (ncs.by_index(1), DROP_ROW_1), + (ncs.integer(), DROP_ROW_1), + ([ncs.numeric() | ~ncs.boolean()], DROP_ROW_1), + (["g", "h"], KEEP_ROW_3), + ([ncs.by_name("g", "h"), "d"], KEEP_ROW_3), + ], +) +def test_drop_nulls_subset( + data_small_dh: Data, subset: OneOrIterable[ColumnNameOrSelector], expected: Data +) -> None: + df = dataframe(data_small_dh) + result = df.drop_nulls(subset) + assert_equal_data(result, expected) + + if TYPE_CHECKING: from typing_extensions import assert_type diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index db65d468a5..8d8f8d994e 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -10,21 +10,13 @@ from narwhals import _plan as nwp from narwhals._plan import expressions as ir, selectors as ncs from narwhals._plan._dispatch import get_dispatch_name -from tests.plan.utils import assert_equal_data, dataframe, named_ir +from tests.plan.utils import assert_equal_data, dataframe, named_ir, re_compile if TYPE_CHECKING: - import sys - import pyarrow as pa - from typing_extensions import TypeAlias from narwhals._plan.dataframe import DataFrame - if sys.version_info >= (3, 11): - _Flags: TypeAlias = "int | re.RegexFlag" - else: - _Flags: TypeAlias = int - @pytest.fixture def data() -> dict[str, Any]: @@ -41,12 +33,6 @@ def df(data: dict[str, Any]) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: return dataframe(data) -def re_compile( - pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE -) -> re.Pattern[str]: - return re.compile(pattern, flags) - - def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: implemented_full = nwp.col("a").is_null() forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) @@ -69,7 +55,7 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: with pytest.raises( 
TypeError, - match=re_compile(r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first"), + match=re_compile(r"RootSelector.+not.+appear.+compliant.+expand.+expr.+first"), ): df._compliant.select(forgot_to_expand) diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py index 203c39911b..72324aeaac 100644 --- a/tests/plan/expr_expansion_test.py +++ b/tests/plan/expr_expansion_test.py @@ -7,25 +7,25 @@ import narwhals as nw from narwhals import _plan as nwp -from narwhals._plan import expressions as ir, selectors as ndcs -from narwhals._plan._expansion import ( - prepare_projection, - replace_selector, - rewrite_special_aliases, +from narwhals._plan import expressions as ir, selectors as ncs +from narwhals._utils import zip_strict +from narwhals.exceptions import ( + ColumnNotFoundError, + DuplicateError, + InvalidOperationError, + MultiOutputExpressionError, ) -from narwhals._plan._parse import parse_into_seq_of_expr_ir -from narwhals._plan.schema import freeze_schema -from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError -from tests.plan.utils import assert_expr_ir_equal, named_ir +from tests.plan.utils import Frame, assert_expr_ir_equal, named_ir, re_compile if TYPE_CHECKING: from collections.abc import Iterable, Sequence from narwhals._plan.typing import IntoExpr, MapIR from narwhals.dtypes import DType + from narwhals.typing import IntoSchema -@pytest.fixture +@pytest.fixture(scope="module") def schema_1() -> dict[str, DType]: return { "a": nw.Int64(), @@ -51,16 +51,21 @@ def schema_1() -> dict[str, DType]: } +@pytest.fixture(scope="module") +def df_1(schema_1: IntoSchema) -> Frame: + return Frame.from_mapping(schema_1) + + MULTI_OUTPUT_EXPRS = ( pytest.param(nwp.col("a", "b", "c")), - pytest.param(ndcs.numeric() - ndcs.matches("[d-j]")), + pytest.param(ncs.numeric() - ncs.matches("[d-j]")), pytest.param(nwp.nth(0, 1, 2)), - pytest.param(ndcs.by_dtype(nw.Int64, nw.Int32, nw.Int16)), - pytest.param(ndcs.by_name("a", "b", "c")), + pytest.param(ncs.by_dtype(nw.Int64, nw.Int32, nw.Int16)), + pytest.param(ncs.by_name("a", "b", "c")), ) """All of these resolve to `["a", "b", "c"]`.""" -BIG_EXCLUDE = ("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", "e", "q") +BIG_EXCLUDE = nwp.exclude("k", "l", "m", "n", "o", "p", "s", "u", "r", "a", "b", "e", "q") def udf_name_map(name: str) -> str: @@ -94,7 +99,7 @@ def udf_name_map(name: str) -> str: ), "HELLO", ), - ( + pytest.param( ( nwp.col("start") .alias("next") @@ -104,24 +109,35 @@ def udf_name_map(name: str) -> str: .alias("noise") .name.suffix("_end") ), - "start_end", + "noise_end", ), ], ) -def test_rewrite_special_aliases_single(expr: nwp.Expr, expected: str) -> None: - # NOTE: We can't use `output_name()` without resolving these rewrites - # Once they're done, `output_name()` just peeks into `Alias(name=...)` - ir_input = expr._ir - with pytest.raises(ComputeError): - ir_input.meta.output_name() - - ir_output = rewrite_special_aliases(ir_input) - assert ir_input != ir_output - actual = ir_output.meta.output_name() - assert actual == expected - - -def alias_replace_guarded(name: str) -> MapIR: # pragma: no cover +def test_special_aliases_single(expr: nwp.Expr, expected: str) -> None: + df = Frame.from_names( + "a", + "B", + "c", + "d", + "aBcD EFg hi", + "hello", + "start", + "ignore me", + "ignore me as well", + "unreferenced", + ) + df.assert_selects(expr, expected) + + +def test_keep_name_no_names(df_1: Frame) -> None: + with pytest.raises( + InvalidOperationError, + 
match=r"expected at least one.+name.+got.+lit.+1.+name.keep", + ): + df_1.project(nwp.lit(1).name.keep()) + + +def alias_replace_guarded(name: str) -> MapIR: """Guards against repeatedly creating the same alias.""" def fn(e_ir: ir.ExprIR) -> ir.ExprIR: @@ -132,7 +148,7 @@ def fn(e_ir: ir.ExprIR) -> ir.ExprIR: return fn -def alias_replace_unguarded(name: str) -> MapIR: # pragma: no cover +def alias_replace_unguarded(name: str) -> MapIR: """**Does not guard against recursion**! Handling the recursion stopping **should be** part of the impl of `ExprIR.map_ir`. @@ -183,75 +199,51 @@ def test_map_ir_recursive(expr: nwp.Expr, function: MapIR, expected: nwp.Expr) - assert_expr_ir_equal(actual, expected) -@pytest.mark.parametrize( - ("expr", "expected"), - [ - (nwp.col("a"), nwp.col("a")), - (nwp.col("a").max().alias("z"), nwp.col("a").max().alias("z")), - (ndcs.string(), ir.Columns(names=("k",))), - ( - ndcs.by_dtype(nw.Datetime("ms"), nw.Date, nw.List(nw.String)), - nwp.col("n", "s"), - ), - (ndcs.string() | ndcs.boolean(), nwp.col("k", "m")), - ( - ~(ndcs.numeric() | ndcs.string()), - nwp.col("l", "m", "n", "o", "p", "q", "r", "s", "u"), - ), - ( - ( - ndcs.all() - - (ndcs.categorical() | ndcs.by_name("a", "b") | ndcs.matches("[fqohim]")) - ^ ndcs.by_name("u", "a", "b", "d", "e", "f", "g") - ).name.suffix("_after"), - nwp.col("a", "b", "c", "f", "j", "k", "l", "n", "r", "s").name.suffix( - "_after" - ), - ), - ( - (ndcs.matches("[a-m]") & ~ndcs.numeric()).sort(nulls_last=True).first() - != nwp.lit(None), - nwp.col("k", "l", "m").sort(nulls_last=True).first() != nwp.lit(None), - ), - ( - ( - ndcs.numeric() - .mean() - .over("k", order_by=ndcs.by_dtype(nw.Date()) | ndcs.boolean()) - ), - ( - nwp.col("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") - .mean() - .over(nwp.col("k"), order_by=nwp.col("m", "n")) - ), - ), - ( - ( - ndcs.datetime() - .dt.timestamp() - .min() - .over(ndcs.string() | ndcs.boolean()) - .last() - .name.to_uppercase() - ), - ( - nwp.col("l", "o") - .dt.timestamp("us") - .min() - .over(nwp.col("k", "m")) - .last() - .name.to_uppercase() - ), - ), - ], -) -def test_replace_selector( - expr: nwp.Selector | nwp.Expr, - expected: nwp.Expr | ir.ExprIR, - schema_1: dict[str, DType], -) -> None: - actual = replace_selector(expr._ir, schema=freeze_schema(**schema_1)) - assert_expr_ir_equal(actual, expected) +def test_expand_selectors_funky_1(df_1: Frame) -> None: + # root->selection->transform + selector = ncs.matches("[a-m]") & ~ncs.numeric() + expr = selector.sort(nulls_last=True).first() != nwp.lit(None) + expecteds = [ + named_ir(name, nwp.col(name).sort(nulls_last=True).first() != nwp.lit(None)) + for name in ("k", "l", "m") + ] + actuals = df_1.project(expr) + for actual, expected in zip_strict(actuals, expecteds): + assert_expr_ir_equal(actual, expected) + + +def test_expand_selectors_funky_2(df_1: Frame) -> None: + # root->selection->transform + # leaf->selection + expr = ncs.numeric().mean().over("k", order_by=ncs.by_dtype(nw.Date) | ncs.boolean()) + root_names = "a", "b", "c", "d", "e", "f", "g", "h", "i", "j" + expecteds = ( + named_ir(name, nwp.col(name).mean().over("k", order_by=("m", "n"))) + for name in root_names + ) + actuals = df_1.project(expr) + for actual, expected in zip_strict(actuals, expecteds): + assert_expr_ir_equal(actual, expected) + + +def test_expand_selectors_funky_3(df_1: Frame) -> None: + # root->selection->transform->rename + # leaf->selection + expr = ( + ncs.datetime() + .dt.timestamp() + .min() + .over(ncs.string() | ncs.boolean()) + .last() + 
.name.to_uppercase() + ) + expecteds = [ + named_ir(name.upper(), nwp.col(name).dt.timestamp().min().over("k", "m").last()) + for name in ("l", "o") + ] + actuals = df_1.project(expr) + for actual, expected in zip_strict(actuals, expecteds): + assert_expr_ir_equal(actual, expected) @pytest.mark.parametrize( @@ -261,13 +253,11 @@ def test_replace_selector( pytest.param( nwp.col("b", "c", "d"), [nwp.col("b"), nwp.col("c"), nwp.col("d")], - id="Columns", + id="ByName(3)", ), - pytest.param(nwp.nth(6), [nwp.col("g")], id="Nth"), + pytest.param(nwp.nth(6), [nwp.col("g")], id="ByIndex(1)"), pytest.param( - nwp.nth(9, 8, -5), - [nwp.col("j"), nwp.col("i"), nwp.col("p")], - id="IndexColumns", + nwp.nth(9, 8, -5), [nwp.col("j"), nwp.col("i"), nwp.col("p")], id="ByIndex(3)" ), pytest.param( [nwp.nth(2).alias("c again"), nwp.nth(-1, -2).name.to_uppercase()], @@ -276,7 +266,7 @@ def test_replace_selector( named_ir("U", nwp.col("u")), named_ir("S", nwp.col("s")), ], - id="Nth-Alias-IndexColumns-Uppercase", + id="ByIndex(1)-Alias-ByIndex(2)-Uppercase", ), pytest.param( nwp.all(), @@ -305,7 +295,7 @@ def test_replace_selector( id="All", ), pytest.param( - (ndcs.numeric() - ndcs.by_dtype(nw.Float32(), nw.Float64())) + (ncs.numeric() - ncs.by_dtype(nw.Float32(), nw.Float64())) .cast(nw.Int64) .mean() .name.suffix("_mean"), @@ -328,8 +318,7 @@ def test_replace_selector( ), pytest.param( ( - (ndcs.numeric() ^ (ndcs.matches(r"[abcdg]") | ndcs.by_name("i", "f"))) - * 100 + (ncs.numeric() ^ (ncs.matches(r"[abcdg]") | ncs.by_name("i", "f"))) * 100 ).name.suffix("_mult_100"), [ named_ir("e_mult_100", (nwp.col("e") * nwp.lit(100))), @@ -339,7 +328,7 @@ def test_replace_selector( id="Selector-XOR-OR-BinaryExpr-Suffix", ), pytest.param( - ndcs.by_dtype(nw.Duration()) + ncs.by_dtype(nw.Duration()) .dt.total_minutes() .name.map(lambda nm: f"total_mins: {nm!r} ?"), [named_ir("total_mins: 'q' ?", nwp.col("q").dt.total_minutes())], @@ -361,7 +350,7 @@ def test_replace_selector( nwp.col("g").cast(nw.String).str.starts_with("1").all(), ), ], - id="Cast-StartsWith-All-Suffix", + id="ByName(2)-Cast-StartsWith-All-Suffix", ), pytest.param( nwp.col("a", "b") @@ -382,10 +371,10 @@ def test_replace_selector( .over(nwp.col("c"), nwp.col("e"), order_by=[nwp.col("d")]), ), ], - id="First-Over-Partitioned-Ordered-Suffix", + id="ByName(2)-First-Over-Partitioned-Ordered-Suffix", ), pytest.param( - nwp.exclude(BIG_EXCLUDE), + BIG_EXCLUDE, [ nwp.col("c"), nwp.col("d"), @@ -398,7 +387,7 @@ def test_replace_selector( id="Exclude", ), pytest.param( - nwp.exclude(BIG_EXCLUDE).name.suffix("_2"), + BIG_EXCLUDE.name.suffix("_2"), [ named_ir("c_2", nwp.col("c")), named_ir("d_2", nwp.col("d")), @@ -411,7 +400,7 @@ def test_replace_selector( id="Exclude-Suffix", ), pytest.param( - nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ndcs.string()), + nwp.col("c").alias("c_min_over_order_by").min().over(order_by=ncs.string()), [ named_ir( "c_min_over_order_by", @@ -421,7 +410,7 @@ def test_replace_selector( id="Alias-Min-Over-Order-By-Selector", ), pytest.param( - (ndcs.by_name("a", "b", "c") / nwp.col("e").first()) + (ncs.by_name("a", "b", "c") / nwp.col("e").first()) .over("g", "f", order_by="f") .name.prefix("hi_"), [ @@ -440,15 +429,41 @@ def test_replace_selector( ], id="Selector-BinaryExpr-Over-Prefix", ), + pytest.param( + [ + nwp.col("c").sort_by(nwp.col("c", "i")).first().alias("ByName(2)"), + nwp.col("c").sort_by("c", "i").first().alias("Column_x2"), + ], + [ + named_ir( + "ByName(2)", nwp.col("c").sort_by(nwp.col("c"), 
nwp.col("i")).first() + ), + named_ir( + "Column_x2", nwp.col("c").sort_by(nwp.col("c"), nwp.col("i")).first() + ), + ], + id="SortBy-ByName", + ), + pytest.param( + nwp.nth(1).mean().over("k", order_by=nwp.nth(4, 5)), + [ + nwp.col("b") + .mean() + .over(nwp.col("k"), order_by=(nwp.col("e"), nwp.col("f"))) + ], + id="Over-OrderBy-ByIndex(2)", + ), + pytest.param( + nwp.col("f").max().over(ncs.by_dtype(nw.Date, nw.Datetime)), + [nwp.col("f").max().over(nwp.col("l"), nwp.col("n"), nwp.col("o"))], + id="Over-Partitioned-Selector", + ), ], ) def test_prepare_projection( - into_exprs: IntoExpr | Sequence[IntoExpr], - expected: Sequence[nwp.Expr], - schema_1: dict[str, DType], + into_exprs: IntoExpr | Sequence[IntoExpr], expected: Sequence[nwp.Expr], df_1: Frame ) -> None: - irs_in = parse_into_seq_of_expr_ir(into_exprs) - actual, _ = prepare_projection(irs_in, schema=schema_1) + actual = df_1.project(into_exprs) assert len(actual) == len(expected) for lhs, rhs in zip(actual, expected): assert_expr_ir_equal(lhs, rhs) @@ -460,20 +475,20 @@ def test_prepare_projection( nwp.all(), nwp.nth(1, 2, 3), nwp.col("a", "b", "c"), - ndcs.boolean() | ndcs.categorical(), - (ndcs.by_name("a", "b") | ndcs.string()), + ncs.boolean() | ncs.categorical(), + (ncs.by_name("a", "b") | ncs.string()), (nwp.col("b", "c") & nwp.col("a")), nwp.col("a", "b").min().over("c", order_by="e"), - (~ndcs.by_dtype(nw.Int64()) - ndcs.datetime()), + (~ncs.by_dtype(nw.Int64()) - ncs.datetime()), nwp.nth(6, 2).abs().cast(nw.Int32()) + 10, *MULTI_OUTPUT_EXPRS, ], ) -def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType]) -> None: - irs = parse_into_seq_of_expr_ir(expr.alias("dupe")) +def test_prepare_projection_duplicate(expr: nwp.Expr, df_1: Frame) -> None: pattern = re.compile(r"\.alias\(.dupe.\)") + expr = expr.alias("dupe") with pytest.raises(DuplicateError, match=pattern): - prepare_projection(irs, schema=schema_1) + df_1.project(expr) @pytest.mark.parametrize( @@ -532,14 +547,11 @@ def test_prepare_projection_duplicate(expr: nwp.Expr, schema_1: dict[str, DType] ], ) def test_prepare_projection_column_not_found( - into_exprs: IntoExpr | Sequence[IntoExpr], - missing: Sequence[str], - schema_1: dict[str, DType], + into_exprs: IntoExpr | Sequence[IntoExpr], missing: Sequence[str], df_1: Frame ) -> None: pattern = re.compile(rf"not found: {re.escape(repr(missing))}") - irs = parse_into_seq_of_expr_ir(into_exprs) with pytest.raises(ColumnNotFoundError, match=pattern): - prepare_projection(irs, schema=schema_1) + df_1.project(into_exprs) @pytest.mark.parametrize( @@ -570,19 +582,16 @@ def test_prepare_projection_column_not_found( def test_prepare_projection_horizontal_alias( into_exprs: IntoExpr | Iterable[IntoExpr], function: Callable[..., nwp.Expr], - schema_1: dict[str, DType], + df_1: Frame, ) -> None: # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2572#discussion_r2139965411 - expr = function(into_exprs) - alias_1 = expr.alias("alias(x1)") - irs = parse_into_seq_of_expr_ir(alias_1) - out_irs, _ = prepare_projection(irs, schema=schema_1) + alias_1 = function(into_exprs).alias("alias(x1)") + out_irs = df_1.project(alias_1) assert len(out_irs) == 1 assert out_irs[0] == named_ir("alias(x1)", function("a", "b", "c")) alias_2 = alias_1.alias("alias(x2)") - irs = parse_into_seq_of_expr_ir(alias_2) - out_irs, _ = prepare_projection(irs, schema=schema_1) + out_irs = df_1.project(alias_2) assert len(out_irs) == 1 assert out_irs[0] == named_ir("alias(x2)", function("a", "b", "c")) @@ -591,9 +600,142 @@ 
def test_prepare_projection_horizontal_alias( "into_exprs", [nwp.nth(-21), nwp.nth(-1, 2, 54, 0), nwp.nth(20), nwp.nth([-10, -100])] ) def test_prepare_projection_index_error( - into_exprs: IntoExpr | Iterable[IntoExpr], schema_1: dict[str, DType] + into_exprs: IntoExpr | Iterable[IntoExpr], df_1: Frame +) -> None: + with pytest.raises( + ColumnNotFoundError, + match=re_compile( + r"invalid.+column.+index.+schema.+last.+column.+`nth\(\-?\d{2,3}\)`" + ), + ): + df_1.project(into_exprs) + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + ( + nwp.nth(range(3)) * nwp.nth(3, 4, 5).max(), + [ + nwp.col("a") * nwp.col("d").max(), + nwp.col("b") * nwp.col("e").max(), + nwp.col("c") * nwp.col("f").max(), + ], + ), + ( + (10 / nwp.col("e", "d", "b", "a")).name.keep(), + [ + named_ir("e", 10 / nwp.col("e")), + named_ir("d", 10 / nwp.col("d")), + named_ir("b", 10 / nwp.col("b")), + named_ir("a", 10 / nwp.col("a")), + ], + ), + ( + ( + (ncs.categorical() | ncs.string()) + .as_expr() + .cast(nw.String) + .str.len_chars() + .name.map(lambda s: f"len_chars({s!r})") + - ncs.by_dtype(nw.UInt16).as_expr() + ).name.suffix("-col('g')"), + [ + named_ir( + "len_chars('k')-col('g')", + nwp.col("k").cast(nw.String).str.len_chars() - nwp.col("g"), + ), + named_ir( + "len_chars('p')-col('g')", + nwp.col("p").cast(nw.String).str.len_chars() - nwp.col("g"), + ), + ], + ), + ( + (nwp.all().first() == nwp.all().last()).name.suffix("_first_eq_last"), + [ + named_ir("a_first_eq_last", nwp.col("a").first() == nwp.col("a").last()), + named_ir("b_first_eq_last", nwp.col("b").first() == nwp.col("b").last()), + named_ir("c_first_eq_last", nwp.col("c").first() == nwp.col("c").last()), + named_ir("d_first_eq_last", nwp.col("d").first() == nwp.col("d").last()), + named_ir("e_first_eq_last", nwp.col("e").first() == nwp.col("e").last()), + named_ir("f_first_eq_last", nwp.col("f").first() == nwp.col("f").last()), + named_ir("g_first_eq_last", nwp.col("g").first() == nwp.col("g").last()), + named_ir("h_first_eq_last", nwp.col("h").first() == nwp.col("h").last()), + named_ir("i_first_eq_last", nwp.col("i").first() == nwp.col("i").last()), + named_ir("j_first_eq_last", nwp.col("j").first() == nwp.col("j").last()), + named_ir("k_first_eq_last", nwp.col("k").first() == nwp.col("k").last()), + named_ir("l_first_eq_last", nwp.col("l").first() == nwp.col("l").last()), + named_ir("m_first_eq_last", nwp.col("m").first() == nwp.col("m").last()), + named_ir("n_first_eq_last", nwp.col("n").first() == nwp.col("n").last()), + named_ir("o_first_eq_last", nwp.col("o").first() == nwp.col("o").last()), + named_ir("p_first_eq_last", nwp.col("p").first() == nwp.col("p").last()), + named_ir("q_first_eq_last", nwp.col("q").first() == nwp.col("q").last()), + named_ir("r_first_eq_last", nwp.col("r").first() == nwp.col("r").last()), + named_ir("s_first_eq_last", nwp.col("s").first() == nwp.col("s").last()), + named_ir("u_first_eq_last", nwp.col("u").first() == nwp.col("u").last()), + ], + ), + ], + ids=["3:3", "1:4", "2:1", "All:All"], +) +def test_expand_binary_expr_combination( + df_1: Frame, expr: nwp.Expr, expected: Iterable[ir.NamedIR | nwp.Expr] ) -> None: - irs = parse_into_seq_of_expr_ir(into_exprs) - pattern = re.compile(r"invalid.+index.+nth", re.DOTALL | re.IGNORECASE) - with pytest.raises(ComputeError, match=pattern): - prepare_projection(irs, schema=schema_1) + actuals = df_1.project(expr) + for actual, expect in zip_strict(actuals, expected): + assert_expr_ir_equal(actual, expect) + + +def 
test_expand_binary_expr_combination_invalid(df_1: Frame) -> None: + # fmt: off + expr = re.escape( + "ncs.all() + ncs.by_name('b', 'c', require_all=True)\n" + "^^^^^^^^^" + ) + # fmt: on + shapes = "(20 != 2)" + pattern = rf"{shapes}.+\n{expr}" + all_to_two = nwp.all() + nwp.col("b", "c") + with pytest.raises(MultiOutputExpressionError, match=pattern): + df_1.project(all_to_two) + + expr = re.escape( + "ncs.by_name('a', 'b', require_all=True).abs().fill_null([lit(int: 0)]).round() * ncs.by_index([9, 10, 11], require_all=True).cast(Int64).sort(asc)\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" + ) + shapes = "(2 != 3)" + pattern = rf"{shapes}.+\n{expr}" + two_to_three = ( + nwp.col("a", "b").abs().fill_null(0).round(2) + * nwp.nth(9, 10, 11).cast(nw.Int64).sort() + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + df_1.project(two_to_three) + + # fmt: off + expr = re.escape( + "ncs.numeric() / [(ncs.numeric()) - (ncs.by_dtype([Int64]))]\n" + "^^^^^^^^^^^^^" + ) + # fmt: on + shapes = "(10 != 9)" + pattern = rf"{shapes}.+\n{expr}" + ten_to_nine = ( + ncs.numeric().as_expr() / (ncs.numeric() - ncs.by_dtype(nw.Int64)).as_expr() + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + df_1.project(ten_to_nine) + + +def test_over_order_by_names() -> None: + expr = nwp.col("a").first().over(order_by=ncs.string()) + e_ir = expr._ir + assert isinstance(e_ir, ir.OrderedWindowExpr) + with pytest.raises( + InvalidOperationError, + match=re_compile( + r"cannot use.+order_by_names.+before.+expansion.+ncs.string\(\)" + ), + ): + list(e_ir.order_by_names()) diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index e556352268..fd3b51f5e5 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -15,6 +15,7 @@ from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expressions import functions as F, operators as ops from narwhals._plan.expressions.literal import SeriesLiteral +from narwhals._plan.expressions.ranges import IntRange from narwhals.exceptions import ( ComputeError, InvalidIntoExprError, @@ -23,7 +24,7 @@ MultiOutputExpressionError, ShapeError, ) -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, re_compile if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -128,7 +129,7 @@ def test_valid_windows() -> None: assert nwp.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") -def test_invalid_repeat_agg() -> None: +def test_repeat_agg_invalid() -> None: with pytest.raises(InvalidOperationError): nwp.col("a").mean().mean() with pytest.raises(InvalidOperationError): @@ -145,7 +146,7 @@ def test_invalid_repeat_agg() -> None: # NOTE: Previously multiple different errors, but they can be reduced to the same thing # Once we are scalar, only elementwise is allowed -def test_invalid_agg_non_elementwise() -> None: +def test_agg_non_elementwise_invalid() -> None: pattern = re.compile(r"cannot use.+rank.+aggregated.+mean", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").mean().rank() @@ -168,7 +169,7 @@ def test_agg_non_elementwise_range_special() -> None: assert isinstance(e_ir.expr.input[1], ir.Len) -def test_invalid_int_range() -> None: +def test_int_range_invalid() -> None: pattern = re.compile(r"scalar.+agg", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.int_range(nwp.col("a")) @@ -178,16 +179,37 @@ def 
test_invalid_int_range() -> None: nwp.int_range(0, nwp.col("a").abs()) with pytest.raises(InvalidOperationError, match=pattern): nwp.int_range(nwp.col("a") + 1) + with pytest.raises(InvalidOperationError, match=pattern): + nwp.int_range((1 + nwp.col("b")).name.keep()) + int_range = IntRange(step=1, dtype=nw.Int64()) + with pytest.raises(InvalidOperationError, match=r"at least 2 inputs.+int_range"): + int_range.to_function_expr(ir.col("a")) + + +@pytest.mark.xfail( + reason="Not implemented `int_range(eager=True)`", raises=NotImplementedError +) +def test_int_range_series() -> None: + assert isinstance(nwp.int_range(50, eager=True), nwp.Series) + +def test_over_invalid() -> None: + with pytest.raises(TypeError, match=r"one of.+partition_by.+or.+order_by"): + nwp.col("a").last().over() -# NOTE: Non-`polars`` rule -def test_invalid_over() -> None: + # NOTE: Non-`polars` rule pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").fill_null(3).over("b") + # NOTE: This version isn't elementwise + expr_ir = nwp.col("a").fill_null(strategy="backward").over("b")._ir + assert isinstance(expr_ir, ir.WindowExpr) + assert isinstance(expr_ir.expr, ir.FunctionExpr) + assert isinstance(expr_ir.expr.function, F.FillNullWithStrategy) -def test_nested_over() -> None: + +def test_over_nested() -> None: pattern = re.compile(r"cannot nest.+over", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").mean().over("b").over("c") @@ -197,7 +219,7 @@ def test_nested_over() -> None: # NOTE: This *can* error in polars, but only if the length **actually changes** # The rule then breaks down to needing the same length arrays in all parts of the over -def test_filtration_over() -> None: +def test_over_filtration() -> None: pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").drop_nulls().over("b") @@ -207,29 +229,9 @@ def test_filtration_over() -> None: nwp.col("a").diff().drop_nulls().over("b", order_by="i") -def test_invalid_binary_expr_multi() -> None: - pattern = re.escape("all() + cols(['b', 'c'])\n ^^^^^^^^^^^^^^^^") - with pytest.raises(MultiOutputExpressionError, match=pattern): - nwp.all() + nwp.col("b", "c") - pattern = re.escape( - "index_columns((1, 2, 3)) * index_columns((4, 5, 6)).max()\n" - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" - ) - with pytest.raises(MultiOutputExpressionError, match=pattern): - nwp.nth(1, 2, 3) * nwp.nth(4, 5, 6).max() - pattern = re.escape( - "cols(['a', 'b', 'c']).abs().fill_null([lit(int: 0)]).round() * index_columns((9, 10)).cast(Int64).sort(asc)\n" - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" - ) - with pytest.raises(MultiOutputExpressionError, match=pattern): - nwp.col("a", "b", "c").abs().fill_null(0).round(2) * nwp.nth(9, 10).cast( - nw.Int64() - ).sort() - - -def test_invalid_binary_expr_length_changing() -> None: +def test_binary_expr_length_changing_invalid() -> None: a = nwp.col("a") - b = nwp.col("b") + b = nwp.col("b").exp() with pytest.raises(LengthChangingExprError): a.unique() + b.unique() @@ -268,7 +270,7 @@ def test_binary_expr_length_changing_agg() -> None: ) -def test_invalid_binary_expr_shape() -> None: +def test_binary_expr_shape_invalid() -> None: pattern = re.compile( re.escape("Cannot combine length-changing expressions with length-preserving"), re.IGNORECASE, @@ -282,6 +284,8 @@ def test_invalid_binary_expr_shape() -> None: a.map_batches(lambda x: x, 
is_elementwise=True) * b.gather_every(1, 0) with pytest.raises(ShapeError, match=pattern): a / b.drop_nulls() + with pytest.raises(ShapeError, match=pattern): + a.fill_null(1) // b.rolling_mean(5) @pytest.mark.parametrize("into_iter", [list, tuple, deque, iter, dict.fromkeys, set]) @@ -327,7 +331,7 @@ def test_is_in_series() -> None: ), ], ) -def test_invalid_is_in(other: Any, context: AbstractContextManager[Any]) -> None: +def test_is_in_invalid(other: Any, context: AbstractContextManager[Any]) -> None: with context: nwp.col("a").is_in(other) @@ -529,3 +533,80 @@ def test_hist_invalid() -> None: a.hist(deque((3, 2, 1))) with pytest.raises(TypeError): a.hist(1) # type: ignore[arg-type] + + +def test_into_expr_invalid() -> None: + pytest.importorskip("polars") + import polars as pl + + with pytest.raises( + TypeError, match=re_compile(r"expected.+narwhals.+got.+polars.+hint") + ): + nwp.col("a").max().over(pl.col("b")) # type: ignore[arg-type] + + +def test_when_invalid() -> None: + pattern = re_compile(r"multi-output expr.+not supported in.+when.+context") + + when = nwp.when(nwp.col("a", "b", "c").is_finite()) + when_then = when.then(nwp.col("d").is_unique()) + when_then_when = when_then.when( + (nwp.median("a", "b", "c") > 2) | nwp.col("d").is_nan() + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when.then(nwp.max("c", "d")) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when_then.otherwise(nwp.min("h", "i", "j")) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when_then_when.then(nwp.col(["b", "y", "e"])) + + +# NOTE: `Then`, `ChainedThen` use multi-inheritance, but **need** to use `Expr.__eq__` +def test_then_equal() -> None: + expr = nwp.col("a").clip(nwp.col("a").kurtosis(), nwp.col("a").log()) + other = "other" + then = nwp.when(a="b").then(nwp.col("c").skew()) + chained_then = then.when("d").then("e") + + assert isinstance(then == expr, nwp.Expr) + assert isinstance(then == other, nwp.Expr) + + assert isinstance(chained_then == expr, nwp.Expr) + assert isinstance(chained_then == other, nwp.Expr) + + assert isinstance(then == chained_then, nwp.Expr) + + +def test_dt_timestamp_invalid() -> None: + assert nwp.col("a").dt.timestamp() + with pytest.raises( + TypeError, match=re_compile(r"invalid.+time_unit.+expected.+got 's'") + ): + nwp.col("a").dt.timestamp("s") + + +def test_dt_truncate_invalid() -> None: + assert nwp.col("a").dt.truncate("1d") + with pytest.raises(ValueError, match=re_compile(r"invalid.+every.+abcd")): + nwp.col("a").dt.truncate("abcd") + + +def test_replace_strict() -> None: + a = nwp.col("a") + remapping = a.replace_strict({1: 3, 2: 4}, return_dtype=nw.Int8) + sequences = a.replace_strict(old=[1, 2], new=[3, 4], return_dtype=nw.Int8()) + assert_expr_ir_equal(remapping, sequences) + + +def test_replace_strict_invalid() -> None: + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + nwp.col("a").replace_strict("b") + + with pytest.raises( + TypeError, + match="`new` argument cannot be used if `old` argument is a Mapping type", + ): + nwp.col("a").replace_strict(old={1: 2, 3: 4}, new=[5, 6, 7]) diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 455fecd114..fdcdb566c5 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -6,7 +6,7 @@ import narwhals as nw from narwhals import _plan as nwp -from narwhals._plan import _parse, expressions as ir, selectors as ndcs +from 
narwhals._plan import _parse, expressions as ir, selectors as ncs from narwhals._plan._rewrites import ( rewrite_all, rewrite_binary_agg_over, @@ -71,7 +71,8 @@ def test_rewrite_elementwise_over_multiple(schema_2: dict[str, DType]) -> None: nwp.col("b", "c").last().replace_strict({1: 2}), "d" ).to_narwhals() assert_expr_ir_equal( - before, "cols(['b', 'c']).last().replace_strict().over([col('d')])" + before, + "ncs.by_name('b', 'c', require_all=True).last().replace_strict().over([col('d')])", ) actual = rewrite_all(before, schema=schema_2, rewrites=[rewrite_elementwise_over]) assert len(actual) == 2 @@ -101,7 +102,7 @@ def test_rewrite_elementwise_over_complex(schema_2: dict[str, DType]) -> None: .alias("x2") ), ~(nwp.col("d").is_duplicated().alias("d*")).alias("d**").over("b"), - ndcs.string().str.contains("some").name.suffix("_some"), + ncs.string().str.contains("some").name.suffix("_some"), ( _to_window_expr(nwp.nth(3, 4, 1).null_count().sqrt(), "f", "g", "j") .to_narwhals() diff --git a/tests/plan/frame_partition_by_test.py b/tests/plan/frame_partition_by_test.py new file mode 100644 index 0000000000..429a78fb20 --- /dev/null +++ b/tests/plan/frame_partition_by_test.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +import narwhals as nw +from narwhals._plan import Selector, selectors as ncs +from narwhals._utils import zip_strict +from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError +from tests.plan.utils import assert_equal_data, dataframe, re_compile + +if TYPE_CHECKING: + from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable + from tests.conftest import Data + + +@pytest.fixture +def data() -> Data: + return { + "a": ["a", "b", "a", None, "b", "c"], + "b": [1, 2, 1, 5, 3, 3], + "c": [5, 4, 3, 6, 2, 1], + } + + +@pytest.mark.parametrize( + ("include_key", "expected"), + [ + ( + True, + [ + {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]}, + {"a": ["b", "b"], "b": [2, 3], "c": [4, 2]}, + {"a": [None], "b": [5], "c": [6]}, + {"a": ["c"], "b": [3], "c": [1]}, + ], + ), + ( + False, + [ + {"b": [1, 1], "c": [5, 3]}, + {"b": [2, 3], "c": [4, 2]}, + {"b": [5], "c": [6]}, + {"b": [3], "c": [1]}, + ], + ), + ], + ids=["include_key", "exclude_key"], +) +@pytest.mark.parametrize( + "by", + ["a", ncs.string(), ncs.matches("a"), ncs.by_name("a"), ncs.by_dtype(nw.String)], + ids=["str", "ncs.string", "ncs.matches", "ncs.by_name", "ncs.by_dtype"], +) +def test_partition_by_single( + data: Data, by: ColumnNameOrSelector, *, include_key: bool, expected: Any +) -> None: + df = dataframe(data) + results = df.partition_by(by, include_key=include_key) + for df, expect in zip_strict(results, expected): + assert_equal_data(df, expect) + + +@pytest.mark.parametrize( + ("include_key", "expected"), + [ + ( + True, + [ + {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]}, + {"a": ["b"], "b": [2], "c": [4]}, + {"a": [None], "b": [5], "c": [6]}, + {"a": ["b"], "b": [3], "c": [2]}, + {"a": ["c"], "b": [3], "c": [1]}, + ], + ), + (False, [{"c": [5, 3]}, {"c": [4]}, {"c": [6]}, {"c": [2]}, {"c": [1]}]), + ], + ids=["include_key", "exclude_key"], +) +@pytest.mark.parametrize( + ("by", "more_by"), + [ + ("a", "b"), + (["a", "b"], ()), + (ncs.matches("a|b"), ()), + (ncs.string(), "b"), + (ncs.by_name("a", "b"), ()), + (ncs.by_name("b"), ncs.by_name("a")), + (ncs.by_dtype(nw.String) | (ncs.numeric() - ncs.by_name("c")), []), + ], + ids=[ + "str-variadic", + "str-list", + "ncs.matches", + "ncs.string-str", + 
"ncs.by_name", + "2x-selector", + "BinarySelector", + ], +) +def test_partition_by_multiple( + data: Data, + by: ColumnNameOrSelector, + more_by: OneOrIterable[ColumnNameOrSelector], + *, + include_key: bool, + expected: Any, +) -> None: + df = dataframe(data) + if isinstance(more_by, (str, Selector)): + results = df.partition_by(by, more_by, include_key=include_key) + else: + results = df.partition_by(by, *more_by, include_key=include_key) + for df, expect in zip_strict(results, expected): + assert_equal_data(df, expect) + + +def test_partition_by_missing_names(data: Data) -> None: + df = dataframe(data) + with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['d']")): + df.partition_by("d") + with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['e']")): + df.partition_by("c", "e") + + +def test_partition_by_duplicate_names(data: Data) -> None: + df = dataframe(data) + with pytest.raises(DuplicateError, match=re_compile(r"expected.+unique.+got.+'c'")): + df.partition_by("c", ncs.numeric()) + + +def test_partition_by_fully_empty_selector(data: Data) -> None: + df = dataframe(data) + with pytest.raises( + ComputeError, match=r"at least one key is required in a group_by operation" + ): + df.partition_by(ncs.array(ncs.numeric()), ncs.struct(), ncs.duration()) + + +# NOTE: Matching polars behavior +def test_partition_by_partially_missing_selector(data: Data) -> None: + df = dataframe(data) + results = df.partition_by(ncs.string() | ncs.list() | ncs.enum()) + expected = nw.Schema({"a": nw.String(), "b": nw.Int64(), "c": nw.Int64()}) + for df in results: + assert df.schema == expected diff --git a/tests/plan/group_by_test.py b/tests/plan/group_by_test.py index e1e60f3605..b7f5035fe2 100644 --- a/tests/plan/group_by_test.py +++ b/tests/plan/group_by_test.py @@ -7,7 +7,7 @@ import narwhals as nw from narwhals import _plan as nwp -from narwhals._plan import selectors as npcs +from narwhals._plan import selectors as ncs from narwhals.exceptions import InvalidOperationError from tests.plan.utils import assert_equal_data, dataframe from tests.utils import PYARROW_VERSION, assert_equal_data as _assert_equal_data @@ -20,7 +20,8 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence - from narwhals._plan.typing import IntoExpr + from narwhals._plan.typing import IntoExpr, OneOrIterable + from tests.conftest import Data def test_group_by_iter() -> None: @@ -358,16 +359,16 @@ def test_fancy_functions() -> None: result = df.group_by("a").agg(nwp.all().std(ddof=0)).sort("a") expected = {"a": [1, 2], "b": [0.5, 0.0]} assert_equal_data(result, expected) - result = df.group_by("a").agg(npcs.numeric().std(ddof=0)).sort("a") + result = df.group_by("a").agg(ncs.numeric().std(ddof=0)).sort("a") assert_equal_data(result, expected) - result = df.group_by("a").agg(npcs.matches("b").std(ddof=0)).sort("a") + result = df.group_by("a").agg(ncs.matches("b").std(ddof=0)).sort("a") assert_equal_data(result, expected) - result = df.group_by("a").agg(npcs.matches("b").std(ddof=0).alias("c")).sort("a") + result = df.group_by("a").agg(ncs.matches("b").std(ddof=0).alias("c")).sort("a") expected = {"a": [1, 2], "c": [0.5, 0.0]} assert_equal_data(result, expected) result = ( df.group_by("a") - .agg(npcs.matches("b").std(ddof=0).name.map(lambda _x: "c")) + .agg(ncs.matches("b").std(ddof=0).name.map(lambda _x: "c")) .sort("a") ) assert_equal_data(result, expected) @@ -406,9 +407,9 @@ def test_fancy_functions() -> None: {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, ["y"], ), - ( - 
[npcs.by_dtype(nw.Float64()).abs()], - [npcs.numeric().sum()], + pytest.param( + [ncs.by_dtype(nw.Float64()).abs()], + [ncs.numeric().sum()], {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, ["y"], ), @@ -454,7 +455,7 @@ def test_group_by_selector() -> None: } result = ( dataframe(data) - .group_by(npcs.by_dtype(nw.Int64), "c") + .group_by(ncs.by_dtype(nw.Int64), "c") .agg(nwp.col("x").mean()) .sort("a", "b") ) @@ -576,9 +577,9 @@ def test_group_by_agg_last( "d": [["three", "one"], ["three"], ["one"]], }, ), - ( + pytest.param( ["d", "c"], - [npcs.string().unique(), nwp.col("b").first().alias("b_first")], + [ncs.string().unique(), nwp.col("b").first().alias("b_first")], { "d": ["one", "one", "three", "three", "three"], "c": [1, 3, 2, 4, 5], @@ -698,10 +699,10 @@ def test_group_by_exclude_keys() -> None: "m": [0, 1, 2], } df = dataframe(data).with_columns( - npcs.boolean().fill_null(False), npcs.numeric().fill_null(0) + ncs.boolean().fill_null(False), ncs.numeric().fill_null(0) ) exclude = "b", "c", "d", "e", "f", "g", "j", "k", "l", "m" - result = df.group_by(nwp.exclude(exclude)).agg(npcs.all().sum()).sort("a", "h") + result = df.group_by(nwp.exclude(exclude)).agg(nwp.all().sum()).sort("a", "h") expected = { "a": ["A", "A", "B"], "h": [False, True, False], @@ -717,3 +718,45 @@ def test_group_by_exclude_keys() -> None: "m": [0, 2, 1], } assert_equal_data(result, expected) + + +IGNORE_KEYS: Data = {"a": [1, 2], "b_sum": [9, 6]} +EXPAND_KEYS: Data = {"a": [1, 2], "a_sum": [2, 2], "b_sum": [9, 6]} + + +@pytest.mark.parametrize( + ("aggs", "expected"), + [ + (nwp.all().sum().name.suffix("_sum"), IGNORE_KEYS), + (ncs.all().sum().name.suffix("_sum"), IGNORE_KEYS), + (ncs.matches(r"a|b").sum().name.suffix("_sum"), IGNORE_KEYS), + (ncs.integer().sum().name.suffix("_sum"), IGNORE_KEYS), + (nwp.col("a", "b").sum().name.suffix("_sum"), EXPAND_KEYS), + (nwp.nth(0, 1).sum().name.suffix("_sum"), EXPAND_KEYS), + ( + [nwp.nth(0).sum().alias("a_sum"), ncs.last().sum().name.suffix("_sum")], + EXPAND_KEYS, + ), + ( + [nwp.col("a").sum().name.suffix("_sum"), nwp.col("b").sum().alias("b_sum")], + EXPAND_KEYS, + ), + ], + ids=[ + "nw.All", + "cs.All", + "Matches", + "Integer", + "ByName", + "ByIndex", + "ByIndex-2", + "Column-2", + ], +) +def test_group_by_consistent_exclude_21773( + aggs: OneOrIterable[IntoExpr], expected: Data +) -> None: + # NOTE: See https://github.com/pola-rs/polars/issues/21773 + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.group_by("a").agg(aggs) + assert_equal_data(result, expected) diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 8e60759d0b..7a286520e5 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -212,3 +212,25 @@ def test_immutable___slots___(immutable_type: type[Immutable]) -> None: slots = immutable_type.__slots__ if slots: assert len(slots) != 0, slots + + +def test_immutable_str() -> None: + class MixedFields(Immutable): + __slots__ = ("name", "unique_id", "aliases") # noqa: RUF023 + name: str + unique_id: int + aliases: tuple[str, str, str] + + class Parent(Immutable): + __slots__ = ("children",) + children: tuple[MixedFields, ...] 
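+
+    # NOTE: Inferred from the expected strings below - `__str__` renders slot
+    # fields in declaration order, recurses into nested `Immutable`s, and
+    # displays tuple values with list-style brackets.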
+ + bob = MixedFields(name="bob", unique_id=123, aliases=("robert", "bobert", "Bob")) + parent = Parent(children=(bob,)) + + expected_child = ( + "MixedFields(name='bob', unique_id=123, aliases=['robert', 'bobert', 'Bob'])" + ) + expected_parent = f"Parent(children=[{expected_child}])" + assert str(bob) == expected_child + assert str(parent) == expected_parent diff --git a/tests/plan/meta_test.py b/tests/plan/meta_test.py index 2b5ca80c35..f7738dbf5e 100644 --- a/tests/plan/meta_test.py +++ b/tests/plan/meta_test.py @@ -1,8 +1,17 @@ from __future__ import annotations +import datetime as dt +import re +import string +from typing import Any + import pytest +import narwhals as nw +import narwhals._plan.selectors as ncs from narwhals import _plan as nwp +from narwhals.exceptions import ComputeError +from tests.plan.utils import series from tests.utils import POLARS_VERSION pytest.importorskip("polars") @@ -23,6 +32,11 @@ LEN_CASE = (nwp.len().alias("count"), pl.count(), "count") +XFAIL_LITERAL_LIST = pytest.mark.xfail( + reason="'list' is not supported in `nw.lit`", raises=TypeError +) + + @pytest.mark.parametrize( ("nw_expr", "pl_expr", "expected"), [ @@ -44,6 +58,11 @@ ), (nwp.all().mean(), pl.all().mean(), []), (nwp.all().mean().sort_by("d"), pl.all().mean().sort_by("d"), ["d"]), + ( + nwp.all_horizontal(*string.ascii_letters), + pl.all_horizontal(*string.ascii_letters), + list(string.ascii_letters), + ), ], ) def test_meta_root_names( @@ -181,3 +200,149 @@ def test_meta_output_name(nw_expr: nwp.Expr, pl_expr: pl.Expr, expected: str) -> nw_result = nw_expr.meta.output_name() assert nw_result == expected assert nw_result == pl_result + + +def test_root_and_output_names() -> None: + e = nwp.col("foo") * nwp.col("bar") + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = nwp.col("foo").filter(bar=13) + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = nwp.sum("foo").over("groups") + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "groups"] + + e = nwp.sum("foo").is_between(nwp.len() - 10, nwp.col("bar")) + assert e.meta.output_name() == "foo" + assert e.meta.root_names() == ["foo", "bar"] + + e = nwp.len() + assert e.meta.output_name() == "len" + + with pytest.raises( + ComputeError, + match=re.escape( + "unable to find root column name for expr 'ncs.all()' when calling 'output_name'" + ), + ): + nwp.all().name.suffix("_").meta.output_name() + + assert ( + nwp.all().name.suffix("_").meta.output_name(raise_if_undetermined=False) is None + ) + + +def test_meta_has_multiple_outputs() -> None: + e = nwp.col(["a", "b"]).name.suffix("_foo") + assert e.meta.has_multiple_outputs() + + +def test_is_column() -> None: + e = nwp.col("foo") + assert e.meta.is_column() + + e = nwp.col("foo").alias("bar") + assert not e.meta.is_column() + + e = nwp.col("foo") * nwp.col("bar") + assert not e.meta.is_column() + + +@pytest.mark.parametrize( + ("expr", "is_column_selection"), + [ + # columns + (nwp.col("foo"), True), + (nwp.col("foo", "bar"), True), + # column expressions + (nwp.col("foo") + 100, False), + (nwp.col("foo").__floordiv__(10), False), + (nwp.col("foo") * nwp.col("bar"), False), + # selectors / expressions + (ncs.numeric() * 100, False), + (ncs.temporal() - ncs.by_dtype(nw.Time), True), + (ncs.numeric().exclude("value"), True), + ((ncs.temporal() - ncs.by_dtype(nw.Time())).exclude("dt"), True), + # top-level selection funcs + (nwp.nth(2), True), + (ncs.first(), True), + (ncs.last(), True), 
+ ], +) +def test_is_column_selection(expr: nwp.Expr, *, is_column_selection: bool) -> None: + if is_column_selection: + assert expr.meta.is_column_selection() + assert expr.meta.is_column_selection(allow_aliasing=True) + expr = ( + expr.name.suffix("!") if expr.meta.has_multiple_outputs() else expr.alias("!") + ) + assert not expr.meta.is_column_selection() + assert expr.meta.is_column_selection(allow_aliasing=True) + else: + assert not expr.meta.is_column_selection() + + +@pytest.mark.parametrize( + "value", + [ + None, + 1234, + 567.89, + float("inf"), + dt.date(2000, 1, 1), + dt.datetime(1974, 1, 1, 12, 45, 1), + dt.time(10, 30, 45), + dt.timedelta(hours=-24), + pytest.param(["x", "y", "z"], marks=XFAIL_LITERAL_LIST), + series([None, None]), + pytest.param([[10, 20], [30, 40]], marks=XFAIL_LITERAL_LIST), + "this is the way", + ], +) +def test_is_literal(value: Any) -> None: + e = nwp.lit(value) + assert e.meta.is_literal() + + e = nwp.lit(value).alias("foo") + assert not e.meta.is_literal() + + e = nwp.lit(value).alias("foo") + assert e.meta.is_literal(allow_aliasing=True) + + +def test_literal_output_name() -> None: + e = nwp.lit(1) + data = 1, 2, 3 + assert e.meta.output_name() == "literal" + + e = nwp.lit(series(data).alias("abc")) + assert e.meta.output_name() == "abc" + + e = nwp.lit(series(data)) + assert e.meta.output_name() == "" + + +# NOTE: Very low-priority +@pytest.mark.xfail( + reason="TODO: `Expr.struct.field` influences `meta.output_name`.", + raises=AssertionError, +) +def test_struct_field_output_name_24003() -> None: + assert nwp.col("ball").struct.field("radius").meta.output_name() == "radius" + + +def test_selector_by_name_single() -> None: + assert ncs.by_name("foo").meta.output_name() == "foo" + + +def test_selector_by_name_multiple() -> None: + with pytest.raises( + ComputeError, + match=re.escape( + "unable to find root column name for expr 'ncs.by_name('foo', 'bar', require_all=True)' when calling 'output_name'" + ), + ): + ncs.by_name(["foo", "bar"]).meta.output_name() diff --git a/tests/plan/over_test.py b/tests/plan/over_test.py new file mode 100644 index 0000000000..c561b8d47e --- /dev/null +++ b/tests/plan/over_test.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +pytest.importorskip("pyarrow") + +import narwhals as nw +import narwhals._plan as nwp +from narwhals._plan import selectors as ncs +from narwhals.exceptions import InvalidOperationError +from tests.plan.utils import assert_equal_data, dataframe, re_compile + +if TYPE_CHECKING: + from narwhals._plan.typing import IntoExprColumn, OneOrIterable + from tests.conftest import Data + + +@pytest.fixture +def data() -> Data: + return { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, 3, 2, 1], + "i": [0, 1, 2, 3, 4], + } + + +@pytest.fixture +def data_with_null(data: Data) -> Data: + return data | {"b": [1, 2, None, 5, 3]} + + +@pytest.fixture +def data_alt() -> Data: + return {"a": [3, 5, 1, 2, None], "b": [0, 1, 3, 2, 1], "c": [9, 1, 2, 1, 1]} + + +@pytest.mark.parametrize( + "partition_by", + [ + "a", + ["a"], + nwp.nth(0), + ncs.first(), + ncs.string(), + ncs.by_dtype(nw.String), + ncs.by_name("a"), + ncs.matches(r"a"), + ncs.all() - ncs.numeric(), + ], +) +def test_over_single(data: Data, partition_by: OneOrIterable[IntoExprColumn]) -> None: + expected = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, 3, 2, 1], + "i": [0, 1, 2, 3, 4], + "c_max": [5, 5, 3, 3, 3], + } + result = ( + 
dataframe(data) + .with_columns(c_max=nwp.col("c").max().over(partition_by)) + .sort("i") + ) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + "partition_by", + [ + ("a", "b"), + [nwp.col("a"), nwp.col("b")], + [nwp.nth(0), nwp.nth(1)], + nwp.col("a", "b"), + nwp.nth(0, 1), + ncs.by_name("a", "b"), + ncs.matches(r"a|b"), + ncs.all() - ncs.by_name(["c", "i"]), + ], + ids=[ + "tuple[str]", + "col-col", + "nth-nth", + "cols", + "index_columns", + "by_name", + "matches", + "binary_selector", + ], +) +def test_over_multiple(data: Data, partition_by: OneOrIterable[IntoExprColumn]) -> None: + expected = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, 3, 2, 1], + "i": [0, 1, 2, 3, 4], + "c_min": [5, 4, 1, 2, 1], + } + result = ( + dataframe(data) + .with_columns(c_min=nwp.col("c").min().over(partition_by)) + .sort("i") + ) + assert_equal_data(result, expected) + + +# NOTE: Not planned +@pytest.mark.xfail( + reason="Native `pyarrow` `group_by` isn't enough", raises=InvalidOperationError +) +def test_over_cum_sum_partition_by(data_with_null: Data) -> None: # pragma: no cover + df = dataframe(data_with_null) + expected = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, None, 5, 3], + "c": [5, 4, 3, 2, 1], + "b_cum_sum": [1, 3, None, 5, 8], + "c_cum_sum": [5, 9, 3, 5, 6], + } + + result = ( + df.with_columns(nwp.col("b", "c").cum_sum().over("a").name.suffix("_cum_sum")) + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) + + +def test_over_std_var(data: Data) -> None: + expected = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, 3, 2, 1], + "i": [0, 1, 2, 3, 4], + "c_std0": [0.5, 0.5, 0.816496580927726, 0.816496580927726, 0.816496580927726], + "c_std1": [0.7071067811865476, 0.7071067811865476, 1.0, 1.0, 1.0], + "c_var0": [ + 0.25, + 0.25, + 0.6666666666666666, + 0.6666666666666666, + 0.6666666666666666, + ], + "c_var1": [0.5, 0.5, 1.0, 1.0, 1.0], + } + + result = ( + dataframe(data) + .with_columns( + c_std0=nwp.col("c").std(ddof=0).over("a"), + c_std1=nwp.col("c").std(ddof=1).over("a"), + c_var0=nwp.col("c").var(ddof=0).over(ncs.string()), + c_var1=nwp.col("c").var(ddof=1).over("a"), + ) + .sort("i") + ) + assert_equal_data(result, expected) + + +# NOTE: Supporting this for pyarrow is new 🥳 +def test_over_anonymous_reduction() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.with_columns(nwp.all().sum().over("a").name.suffix("_sum")).sort("a", "b") + expected = {"a": [1, 1, 2], "b": [4, 5, 6], "a_sum": [2, 2, 2], "b_sum": [9, 9, 6]} + assert_equal_data(result, expected) + + +def test_over_raise_len_change(data: Data) -> None: + df = dataframe(data) + with pytest.raises(InvalidOperationError): + df.select(nwp.col("b").drop_nulls().over("a")) + + +# NOTE: Slightly different error, but same reason for raising +# (expr-ir): InvalidOperationError: `cum_sum()` is not supported in a `group_by` context +# (main): NotImplementedError: Only aggregation or literal operations are supported in grouped `over` context for PyArrow. 
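+# For reference, the pinned link below points at the `main` implementation: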
+# https://github.com/narwhals-dev/narwhals/blob/ecde261d799a711c2e0a7acf11b108bc45035dc9/narwhals/_arrow/expr.py#L116-L118 +def test_unsupported_over(data: Data) -> None: + df = dataframe(data) + with pytest.raises(InvalidOperationError): + df.select(nwp.col("a").shift(1).cum_sum().over("b")) + + +def test_over_without_partition_by() -> None: + df = dataframe({"a": [1, -1, 2], "i": [0, 2, 1]}) + result = ( + df.with_columns(b=nwp.col("a").abs().cum_sum().over(order_by="i")) + .sort("i") + .select("a", "b", "i") + ) + expected = {"a": [1, 2, -1], "b": [1, 3, 4], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +def test_aggregation_over_without_partition_by() -> None: + df = dataframe({"a": [1, -1, 2], "i": [0, 2, 1]}) + result = ( + df.with_columns(b=nwp.col("a").diff().sum().over(order_by="i")) + .sort("i") + .select("a", "b", "i") + ) + expected = {"a": [1, 2, -1], "b": [-2, -2, -2], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +def test_len_over_2369() -> None: + df = dataframe({"a": [1, 2, 4], "b": ["x", "x", "y"]}) + result = df.with_columns(a_len_per_group=nwp.len().over("b")).sort("a") + expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]} + assert_equal_data(result, expected) + + +def test_shift_kitchen_sink(data_alt: Data) -> None: + result = dataframe(data_alt).select( + nwp.nth(1, 2) + .shift(-1) + .over(order_by=ncs.last()) + .sort(nulls_last=True) + .fill_null(100) + * 5 + ) + expected = {"b": [0, 5, 10, 15, 500], "c": [5, 5, 10, 45, 500]} + assert_equal_data(result, expected) + + +def test_over_order_by_expr(data_alt: Data) -> None: + df = dataframe(data_alt) + result = df.select( + nwp.all() + + nwp.all().last().over(order_by=[nwp.nth(1), ncs.first()], descending=True) + ) + expected = {"a": [6, 8, 4, 5, None], "b": [0, 1, 3, 2, 1], "c": [18, 10, 11, 10, 10]} + assert_equal_data(result, expected) + + +def test_over_order_by_expr_invalid(data_alt: Data) -> None: + df = dataframe(data_alt) + with pytest.raises( + InvalidOperationError, + match=re_compile(r"only.+column.+selection.+in.+order_by.+found.+sort"), + ): + df.select(nwp.col("a").first().over(order_by=nwp.col("b").sort())) + + +def test_null_count_over() -> None: + data = { + "a": ["a", "b", None, None, "b", "c"], + "b": [1, 2, 1, 5, 3, 3], + "c": [5, 4, 3, 6, 2, 1], + } + expected = { + "a": ["a", "b", None, None, "b", "c"], + "b": [1, 2, 1, 5, 3, 3], + "c": [5, 4, 3, 6, 2, 1], + "first_null_count_over_b": [1, 0, 1, 1, 0, 0], + } + df = dataframe(data) + result = df.with_columns( + first_null_count_over_b=ncs.first() + .null_count() + .over(ncs.integer() - ncs.by_name("c")) + ) + assert_equal_data(result, expected) diff --git a/tests/plan/repr_test.py b/tests/plan/repr_test.py new file mode 100644 index 0000000000..325ed52621 --- /dev/null +++ b/tests/plan/repr_test.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import narwhals._plan as nwp + + +def test_repr() -> None: + nwp.col("a").meta.as_selector() + expr = nwp.col("a") + selector = expr.meta.as_selector() + + expr_repr_html = expr._repr_html_() + expr_ir_repr_html = expr._ir._repr_html_() + selector_repr_html = selector._repr_html_() + selector_ir_repr_html = selector._ir._repr_html_() + expr_repr = expr.__repr__() + expr_ir_repr = expr._ir.__repr__() + selector_repr = selector.__repr__() + selector_ir_repr = selector._ir.__repr__() + + # In a notebook, both `Expr` and `ExprIR` are displayed the same + assert expr_repr_html == expr_ir_repr_html + # The actual repr (for debugging) has more information 
+    assert expr_repr != expr_repr_html
+    # Currently, all extra information is *before* the part which matches
+    assert expr_repr.endswith(expr_repr_html)
+    # But these guys should not deviate
+    assert expr_ir_repr == expr_ir_repr_html
+    # The same invariants should hold for `Selector` and `SelectorIR`
+    assert selector_repr_html == selector_ir_repr_html
+    assert selector_repr != selector_repr_html
+    assert selector_repr.endswith(selector_repr_html)
+    assert selector_ir_repr == selector_ir_repr_html
+    # But they must still be visually different from `Expr` and `ExprIR`
+    assert selector_repr_html != expr_repr_html
+    assert selector_repr != expr_repr
diff --git a/tests/plan/schema_test.py b/tests/plan/schema_test.py
new file mode 100644
index 0000000000..19ecedac77
--- /dev/null
+++ b/tests/plan/schema_test.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+import pytest
+
+import narwhals as nw
+from narwhals._plan.schema import FrozenSchema, freeze_schema
+from tests.plan.utils import dataframe
+
+
+def test_schema() -> None:
+    mapping = {"a": nw.Int64(), "b": nw.String()}
+    schema = nw.Schema(mapping)
+    frozen_schema = freeze_schema(mapping)
+
+    assert frozen_schema.keys() == schema.keys()
+    assert tuple(frozen_schema.values()) == tuple(schema.values())
+
+    # NOTE: Would type-check if `Schema.__init__` didn't make Liskov unhappy
+    assert schema == nw.Schema(frozen_schema)  # type: ignore[arg-type]
+    assert mapping == dict(frozen_schema)
+
+    assert frozen_schema == freeze_schema(mapping)
+    assert frozen_schema == freeze_schema(**mapping)
+    assert frozen_schema == freeze_schema(a=nw.Int64(), b=nw.String())
+    assert frozen_schema == freeze_schema(schema)
+    assert frozen_schema == freeze_schema(frozen_schema)
+    assert frozen_schema == freeze_schema(frozen_schema.items())
+
+    # NOTE: Using `**` unpacking, despite not inheriting from `Mapping` or `dict`
+    assert frozen_schema == freeze_schema(**frozen_schema)
+
+    # NOTE: Using `HasSchema`
+    df = dataframe({"a": [1, 2, 3], "b": ["c", "d", "e"]})
+    assert frozen_schema == freeze_schema(df)
+
+    # NOTE: In case this all looks *too good* to be true
+    assert frozen_schema != freeze_schema(**mapping, c=nw.Float64())
+
+    assert frozen_schema["a"] == schema["a"]
+
+    assert frozen_schema.get("c") is None
+    assert frozen_schema.get("c", nw.Unknown) is nw.Unknown
+    assert frozen_schema.get("c", nw.Unknown()) == nw.Unknown()
+
+    assert "b" in frozen_schema
+    assert "e" not in frozen_schema
+
+    with pytest.raises(TypeError, match="Cannot subclass 'FrozenSchema'"):
+
+        class MutableSchema(FrozenSchema): ...  # type: ignore[misc]
diff --git a/tests/plan/selectors_test.py b/tests/plan/selectors_test.py
new file mode 100644
index 0000000000..56ba3b1251
--- /dev/null
+++ b/tests/plan/selectors_test.py
@@ -0,0 +1,751 @@
+"""Tests adapted from [upstream]. 
+ +[upstream]: https://github.com/pola-rs/polars/blob/84d66e960e3d462811f0575e0a6e4e78e34c618c/py-polars/tests/unit/test_selectors.py +""" + +from __future__ import annotations + +import operator +import re +from datetime import timezone +from typing import TYPE_CHECKING + +import pytest + +import narwhals as nw +import narwhals.stable.v1 as nw_v1 +from narwhals import _plan as nwp +from narwhals._plan import Selector, selectors as ncs +from narwhals._plan._guards import is_expr, is_selector +from narwhals._utils import zip_strict +from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError +from tests.plan.utils import ( + Frame, + assert_expr_ir_equal, + assert_not_selector, + is_expr_ir_equal, + named_ir, + re_compile, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from narwhals._plan.typing import IntoExpr, OperatorFn + from narwhals.dtypes import DType + + +@pytest.fixture(scope="module") +def schema_nested_1() -> nw.Schema: + return nw.Schema( + { + "a": nw.Int32(), + "b": nw.List(nw.Int32()), + "c": nw.List(nw.UInt32), + "d": nw.Array(nw.Int32(), 3), + "e": nw.List(nw.String), + "f": nw.Struct({"x": nw.Int32()}), + } + ) + + +@pytest.fixture(scope="module") +def schema_nested_2() -> nw.Schema: + return nw.Schema( + { + "a": nw.Int32(), + "b": nw.Array(nw.Int32(), 4), + "c": nw.Array(nw.UInt32(), 4), + "d": nw.Array(nw.Int32, 3), + "e": nw.List(nw.Int32()), + "f": nw.Array(nw.String(), 4), + "g": nw.Struct({"x": nw.Int32()}), + } + ) + + +@pytest.fixture(scope="module") +def schema_non_nested() -> nw.Schema: + return nw.Schema( + { + "abc": nw.UInt16(), + "bbb": nw.UInt32(), + "cde": nw.Float64(), + "def": nw.Float32(), + "eee": nw.Boolean(), + "fgg": nw.Boolean(), + "ghi": nw.Time(), + "JJK": nw.Date(), + "Lmn": nw.Duration(), + "opp": nw.Datetime("ms"), + "qqR": nw.String(), + } + ) + + +@pytest.fixture(scope="module") +def schema_mixed() -> nw.Schema: + return nw.Schema( + { + "a": nw.Int64(), + "b": nw.Int32(), + "c": nw.Int16(), + "d": nw.Int8(), + "e": nw.UInt64(), + "f": nw.UInt32(), + "g": nw.UInt16(), + "h": nw.UInt8(), + "i": nw.Float64(), + "j": nw.Float32(), + "k": nw.String(), + "l": nw.Datetime(), + "m": nw.Boolean(), + "n": nw.Date(), + "o": nw.Datetime(), + "p": nw.Categorical(), + "q": nw.Duration(), + "r": nw.Enum(["A", "B", "C"]), + "s": nw.List(nw.String()), + "u": nw.Struct({"a": nw.Int64(), "k": nw.String}), + } + ) + + +@pytest.fixture(scope="module") +def df_datetime() -> Frame: + return Frame.from_mapping( + { + "d1": nw.Datetime("ns", "Asia/Tokyo"), + "d2": nw.Datetime("ns", "UTC"), + "d3": nw.Datetime("us", "UTC"), + "d4": nw.Datetime("us"), + "d5": nw.Datetime("ms"), + } + ) + + +def test_selector_all(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + + df.assert_selects(ncs.all(), *df.columns) + df.assert_selects(~ncs.all()) + df.assert_selects(~(~ncs.all()), *df.columns) + + selector_and_col = ncs.all() & nwp.col("abc") + df.assert_selects(selector_and_col, "abc") + + +def test_selector_by_dtype(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + + selector = ncs.boolean() | ncs.by_dtype(nw.UInt16) + df.assert_selects(selector, "abc", "eee", "fgg") + + selector = ~( + ncs.integer() | ncs.by_dtype(nw.Date(), nw.Datetime, nw.Duration, nw.Time()) + ) + df.assert_selects(selector, "cde", "def", "eee", "fgg", "qqR") + + selector = ncs.by_dtype(nw.Datetime("ns"), nw.Float32(), nw.UInt32, nw.Date) + df.assert_selects(selector, "bbb", "def", "JJK") + selector = ncs.by_dtype( + 
nw.Int64,
+        nw.Int128,
+        nw.Duration("ns"),
+        nw.Int8,
+        nw.Binary(),
+        nw.Int32(),
+        nw.String,
+    )
+    expected = "ncs.by_dtype([Binary, Duration(time_unit='ns'), Int128, Int32, Int64, Int8, String])"
+    assert_expr_ir_equal(selector, expected)
+
+
+def test_selector_by_dtype_timezone_decimal() -> None:
+    df = Frame.from_mapping(
+        {
+            "idx": nw.Decimal(),
+            "dt1": nw.Datetime("ms"),
+            "dt2": nw.Datetime(time_zone="Asia/Tokyo"),
+        }
+    )
+    df.assert_selects(ncs.by_dtype(nw.Decimal), "idx")
+    df.assert_selects(ncs.by_dtype(nw.Datetime(time_zone="Asia/Tokyo")), "dt2")
+    df.assert_selects(ncs.by_dtype(nw.Datetime("ms", None)), "dt1")
+    df.assert_selects(ncs.by_dtype(nw.Datetime), "dt1", "dt2")
+
+
+def test_selector_by_dtype_empty(schema_non_nested: nw.Schema) -> None:
+    df = Frame(schema_non_nested)
+    # empty selection selects nothing
+    df.assert_selects(ncs.by_dtype())
+    df.assert_selects(ncs.by_dtype([]))
+
+
+@pytest.mark.parametrize(
+    ("dtypes", "expected"),
+    [
+        (
+            [
+                nw.Datetime,
+                nw.Enum,
+                nw.Datetime("s"),
+                nw.Duration,
+                nw.Struct,
+                nw.List,
+                nw.Array,
+            ],
+            ["l", "o", "q", "r", "s", "u"],
+        ),
+        ([nw.String(), nw.Boolean], ["k", "m"]),
+        ([nw.Datetime("ms"), nw.Date, nw.List(nw.String)], ["n", "s"]),
+        (
+            [
+                nw.Enum(["A", "B", "c"]),
+                nw.Struct({"a": nw.List(nw.Int64), "k": nw.String}),
+            ],
+            [],
+        ),
+    ],
+)
+def test_selector_by_dtype_mixed(
+    schema_mixed: nw.Schema,
+    dtypes: Iterable[DType | type[DType]],
+    expected: Iterable[str],
+) -> None:
+    df = Frame(schema_mixed)
+    df.assert_selects(ncs.by_dtype(*dtypes), *expected)
+    df.assert_selects(ncs.by_dtype(dtypes), *expected)
+
+
+def test_selector_by_dtype_invalid_input() -> None:
+    with pytest.raises(TypeError):
+        ncs.by_dtype(999)  # type: ignore[arg-type]
+
+
+def test_selector_by_index(schema_non_nested: nw.Schema) -> None:
+    df = Frame(schema_non_nested)
+
+    # one or more positive indices
+    df.assert_selects(ncs.by_index(0), "abc")
+    df.assert_selects(ncs.first(), "abc")
+    df.assert_selects(nwp.nth(0, 1, 2), "abc", "bbb", "cde")
+    df.assert_selects(ncs.by_index(0, 1, 2), "abc", "bbb", "cde")
+
+    # one or more negative indices
+    df.assert_selects(ncs.by_index(-1), "qqR")
+    df.assert_selects(ncs.last(), "qqR")
+    df.assert_selects(ncs.by_index(-2, -1), "opp", "qqR")
+
+    # range objects
+    df.assert_selects(ncs.by_index(range(3)), "abc", "bbb", "cde")
+
+    # exclude by index
+    df.assert_selects(
+        ~ncs.by_index(range(0, df.width, 2)), "bbb", "def", "fgg", "JJK", "opp"
+    )
+
+    df.assert_selects(ncs.by_index(0, 999, require_all=False), "abc")
+    df.assert_selects(ncs.by_index(-1, -999, require_all=False), "qqR")
+    df.assert_selects(ncs.by_index(1234, 5678, require_all=False))
+
+
+def test_selector_by_index_invalid_input() -> None:
+    with pytest.raises(TypeError):
+        ncs.by_index("one")  # type: ignore[arg-type]
+
+    with pytest.raises(TypeError):
+        ncs.by_index(["two", "three"])  # type: ignore[list-item]
+
+
+def test_selector_by_index_not_found(schema_non_nested: nw.Schema) -> None:
+    df = Frame(schema_non_nested)
+
+    with pytest.raises(ColumnNotFoundError):
+        df.project(ncs.by_index(999))
+
+    df.assert_selects(ncs.by_index(999, -50, require_all=False))
+
+    df = Frame(nw.Schema())
+    df.assert_selects(ncs.by_index(111, -112, require_all=False))
+
+
+def test_selector_by_index_reordering(schema_non_nested: nw.Schema) -> None:
+    df = Frame(schema_non_nested)
+
+    df.assert_selects(ncs.by_index(-3, -2, -1), "Lmn", "opp", "qqR")
+    df.assert_selects(ncs.by_index(range(-3, 0)), "Lmn", "opp", "qqR")
+    df.assert_selects(
+        
ncs.by_index(-3, 999, -2, -1, -48, require_all=False), "Lmn", "opp", "qqR" + ) + + +def test_selector_by_name(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + + df.assert_selects(ncs.by_name("abc", "cde"), "abc", "cde") + + selector = ~ncs.by_name("abc", "cde", "ghi", "Lmn", "opp", "eee") + df.assert_selects(selector, "bbb", "def", "fgg", "JJK", "qqR") + df.assert_selects(ncs.by_name()) + df.assert_selects(ncs.by_name([])) + + df.assert_selects(ncs.by_name("???", "fgg", "!!!", require_all=False), "fgg") + + df.assert_selects(ncs.by_name("missing", require_all=False)) + df.assert_selects(ncs.by_name("???", require_all=False)) + + # check "by_name & col" + df.assert_selects(ncs.by_name("abc", "cde") & nwp.col("ghi")) + df.assert_selects(ncs.by_name("abc", "cde") & nwp.col("cde"), "cde") + df.assert_selects(ncs.by_name("cde") & ncs.by_name("cde", "abc"), "cde") + + # check "by_name & by_name" + selector = ncs.by_name("abc", "cde", "def", "eee") & ncs.by_name("cde", "eee", "fgg") + df.assert_selects(selector, "cde", "eee") + + +def test_selector_by_name_or_col(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + df.assert_selects(ncs.by_name("abc") | nwp.col("cde"), "abc", "cde") + + +def test_selector_by_name_not_found(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + + with pytest.raises(ColumnNotFoundError): + df.project(ncs.by_name("xxx", "fgg", "!!!")) + + with pytest.raises(ColumnNotFoundError): + df.project(ncs.by_name("stroopwafel")) + + +def test_selector_by_name_invalid_input() -> None: + with pytest.raises(TypeError): + ncs.by_name(999) # type: ignore[arg-type] + + +def test_selector_first_last(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + first_name = "abc" + mid_names = "bbb", "cde", "def", "eee", "fgg", "ghi", "JJK", "Lmn", "opp" + last_name = "qqR" + + df.assert_selects(ncs.first(), first_name) + df.assert_selects(~ncs.first(), *mid_names, last_name) + df.assert_selects(ncs.last(), last_name) + df.assert_selects(~ncs.last(), first_name, *mid_names) + df.assert_selects(ncs.last() | ncs.first(), first_name, last_name) + + assert_expr_ir_equal(ncs.first(), "ncs.first()") + assert_expr_ir_equal(ncs.last(), "ncs.last()") + assert_expr_ir_equal(ncs.by_index(0), "ncs.first()") + assert_expr_ir_equal(ncs.by_index(-1), "ncs.last()") + + repr_other = repr(ncs.by_index(1)) + assert "ncs.by_index(" in repr_other + assert repr_other == repr_other.replace("ncs.first", "").replace("ncs.last", "") + + +def test_selector_datetime(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + df.assert_selects(ncs.datetime(), "opp") + df.assert_selects(ncs.datetime("ns")) + all_columns = list(df.columns) + all_columns.remove("opp") + df.assert_selects(~ncs.datetime(), *all_columns) + + +@pytest.mark.parametrize( + ("selector", "expected"), + [ + (ncs.datetime(), ("d1", "d2", "d3", "d4", "d5")), + (~ncs.datetime(), ()), + (ncs.datetime(["ms", "ns"]), ("d1", "d2", "d5")), + (ncs.datetime(["ms", "ns"], time_zone="*"), ("d1", "d2")), + (~ncs.datetime(["ms", "ns"]), ("d3", "d4")), + (~ncs.datetime(["ms", "ns"], time_zone="*"), ("d3", "d4", "d5")), + ( + ncs.datetime(time_zone=["UTC", "Asia/Tokyo", "Europe/London"]), + ("d1", "d2", "d3"), + ), + (ncs.datetime(time_zone="*"), ("d1", "d2", "d3")), + (ncs.datetime("ns", time_zone="*"), ("d1", "d2")), + (ncs.datetime(time_zone="UTC"), ("d2", "d3")), + (ncs.datetime("us", time_zone="UTC"), ("d3",)), + (ncs.datetime(time_zone="Asia/Tokyo"), ("d1",)), + 
(ncs.datetime("us", time_zone="Asia/Tokyo"), ()), + (ncs.datetime(time_zone=None), ("d4", "d5")), + (ncs.datetime("ns", time_zone=None), ()), + (~ncs.datetime(time_zone="*"), ("d4", "d5")), + (~ncs.datetime("ns", time_zone="*"), ("d3", "d4", "d5")), + (~ncs.datetime(time_zone="UTC"), ("d1", "d4", "d5")), + (~ncs.datetime("us", time_zone="UTC"), ("d1", "d2", "d4", "d5")), + (~ncs.datetime(time_zone="Asia/Tokyo"), ("d2", "d3", "d4", "d5")), + (~ncs.datetime("us", time_zone="Asia/Tokyo"), ("d1", "d2", "d3", "d4", "d5")), + (~ncs.datetime(time_zone=None), ("d1", "d2", "d3")), + (~ncs.datetime("ns", time_zone=None), ("d1", "d2", "d3", "d4", "d5")), + (ncs.datetime("ns"), ("d1", "d2")), + (ncs.datetime("us"), ("d3", "d4")), + (ncs.datetime("ms"), ("d5",)), + ], +) +def test_selector_datetime_exhaustive( + df_datetime: Frame, selector: Selector, expected: tuple[str, ...] +) -> None: + df = df_datetime + df.assert_selects(selector, *expected) + + +# NOTE: The test is *technically* passing, but the `TypeError` is being raised by `set(time_unit)` +# `TypeError: 'int' object is not iterable` +def test_selector_datetime_invalid_input() -> None: + with pytest.raises(TypeError): + ncs.datetime(999) # type: ignore[arg-type] + + +def test_selector_duration(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + + df.assert_selects(ncs.duration("ms")) + df.assert_selects(ncs.duration(["ms", "ns"])) + df.assert_selects(ncs.duration(), "Lmn") + + df = Frame.from_mapping( + {"d1": nw.Duration("ns"), "d2": nw.Duration("us"), "d3": nw.Duration("ms")} + ) + df.assert_selects(ncs.duration("us"), "d2") + df.assert_selects(ncs.duration(["ms", "ns"]), "d1", "d3") + df.assert_selects(ncs.duration(), "d1", "d2", "d3") + + +def test_selector_matches(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + # NOTE: python's `re` raises on the original pattern this test used + # > re.PatternError: global flags not at the start of the expression at position 1 + # https://github.com/pola-rs/polars/blob/84d66e960e3d462811f0575e0a6e4e78e34c618c/py-polars/tests/unit/test_selectors.py#L499 + pattern_str = r"(?i)[E-N]{3}" + # We can get closer though, by accepting pre-compiled pattern + pattern = re.compile(r"^[E-N]{3}$", re.IGNORECASE) + positive = "eee", "fgg", "ghi", "JJK", "Lmn" + negative = "abc", "bbb", "cde", "def", "opp", "qqR" + + df.assert_selects(ncs.matches(pattern_str), *positive) + df.assert_selects(ncs.matches(pattern), *positive) + + df.assert_selects(~ncs.matches(pattern_str), *negative) + df.assert_selects(~ncs.matches(pattern), *negative) + + +def test_selector_categorical(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + df.assert_selects(ncs.categorical()) + + df = Frame.from_mapping({"a": nw.String(), "b": nw.Binary(), "c": nw.Categorical()}) + df.assert_selects(ncs.categorical(), "c") + df.assert_selects(~ncs.categorical(), "a", "b") + + +def test_selector_numeric(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + df.assert_selects(ncs.numeric(), "abc", "bbb", "cde", "def") + df.assert_selects(ncs.numeric() - ncs.by_dtype(nw.UInt16), "bbb", "cde", "def") + df.assert_selects(~ncs.numeric(), "eee", "fgg", "ghi", "JJK", "Lmn", "opp", "qqR") + + +def test_selector_temporal(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + positive = "ghi", "JJK", "Lmn", "opp" + negative = "abc", "bbb", "cde", "def", "eee", "fgg", "qqR" + df.assert_selects(ncs.temporal(), *positive) + df.assert_selects(~ncs.temporal(), 
*negative) + + +def test_selector_float(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + positive = "cde", "def" + negative = "abc", "bbb", "eee", "fgg", "ghi", "JJK", "Lmn", "opp", "qqR" + df.assert_selects(ncs.float(), *positive) + df.assert_selects(~ncs.float(), *negative) + + +def test_selector_integer(schema_non_nested: nw.Schema) -> None: + df = Frame(schema_non_nested) + positive = "abc", "bbb" + negative = "cde", "def", "eee", "fgg", "ghi", "JJK", "Lmn", "opp", "qqR" + df.assert_selects(ncs.integer(), *positive) + df.assert_selects(~ncs.integer(), *negative) + + +def test_selector_expansion() -> None: + # https://github.com/pola-rs/polars/blob/84d66e960e3d462811f0575e0a6e4e78e34c618c/py-polars/tests/unit/test_selectors.py#L619 + df = Frame.from_names(*list("abcde")) + + s1 = nwp.all().meta.as_selector() + s2 = nwp.col(["a", "b"]).meta.as_selector() + s = s1 - s2 + df.assert_selects(s, "c", "d", "e") + + s1 = ncs.matches("^a|b$") + s = s1 | nwp.col(["d", "e"]).meta.as_selector() + df.assert_selects(s, "a", "b", "d", "e") + + s = s - nwp.col("d").meta.as_selector() + df.assert_selects(s, "a", "b", "e") + + # add a duplicate, this tests if they are pruned + s = s | nwp.col("a").meta.as_selector() + df.assert_selects(s, "a", "b", "e") + + s1e = nwp.col(["a", "b", "c"]) + s2e = nwp.col(["b", "c", "d"]) + + s = s1e.meta.as_selector() + s = s & s2e.meta.as_selector() + df.assert_selects(s, "b", "c") + + with pytest.raises( + InvalidOperationError, match=re_compile(r"cannot turn.+max.+into a selector") + ): + nwp.col("a").max().meta.as_selector() + + +def test_selector_set_ops(schema_non_nested: nw.Schema, schema_mixed: nw.Schema) -> None: + df = Frame(schema_non_nested) + + temporal = ncs.temporal() + + # or + selector = temporal | ncs.string() | ncs.matches(r"^e") + df.assert_selects(selector, "eee", "ghi", "JJK", "Lmn", "opp", "qqR") + + # and + selector = temporal & ncs.matches(r"opp|JJK") + df.assert_selects(selector, "JJK", "opp") + + # SET A - SET B + selector = temporal - ncs.matches(r"opp|JJK") + df.assert_selects(selector, "ghi", "Lmn") + # NOTE: `cs.exclude` was used, but `narwhals` doesn't have it + # Would allow: `str | Expr | DType | type[DType] | Selector | Collection[str | Expr | DType | type[DType] | Selector]` + selector = ncs.all() - (~temporal | ncs.matches(r"opp|JJK")) + df.assert_selects(selector, "ghi", "Lmn") + selector = nwp.all().exclude("opp", "JJK").meta.as_selector() - (~temporal) + df.assert_selects(selector, "ghi", "Lmn") + + sub_expr = ncs.matches("[yz]$") - nwp.col("colx") + assert_not_selector(sub_expr) + + with pytest.raises(TypeError, match=r"unsupported .* \('Expr' - 'Selector'\)"): + nwp.col("colx") - ncs.matches("[yz]$") + + # complement + selector = ~ncs.by_dtype([nw.Duration, nw.Time]) + df.assert_selects( + selector, "abc", "bbb", "cde", "def", "eee", "fgg", "JJK", "opp", "qqR" + ) + + # exclusive or + expected = "abc", "bbb", "eee", "fgg", "ghi" + df.assert_selects(ncs.matches("e|g") ^ ncs.numeric(), *expected) + df.assert_selects(ncs.matches(r"b|g") ^ nwp.col("eee"), *expected) + + df = Frame(schema_mixed) + selector = ~(ncs.numeric() | ncs.string()) + df.assert_selects(selector, "l", "m", "n", "o", "p", "q", "r", "s", "u") + + +def _is_binary_operator(function: OperatorFn) -> bool: + return function in {operator.and_, operator.or_, operator.xor} + + +def _is_selector_operator(function: OperatorFn) -> bool: + return function in {operator.and_, operator.or_, operator.xor, operator.sub} + + +@pytest.mark.parametrize( + 
"arg_2", + [1, nwp.col("a"), nwp.col("a").max(), ncs.numeric()], + ids=["Scalar", "Column", "Expr", "Selector"], +) +@pytest.mark.parametrize( + "function", [operator.and_, operator.or_, operator.xor, operator.add, operator.sub] +) +def test_selector_arith_binary_ops( + arg_2: IntoExpr | Selector, function: OperatorFn +) -> None: + # NOTE: These are the `polars.selectors` semantics + # Parts of it may change with `polars>=2.0`, due to how confusing they are + arg_1 = ncs.string() + result_1 = function(arg_1, arg_2) + if ( + _is_binary_operator(function) + and is_expr(arg_2) + and is_expr_ir_equal(arg_2, nwp.col("a")) + ) or (_is_selector_operator(function) and is_selector(arg_2)): + assert is_selector(result_1) + else: + assert_not_selector(result_1) + + if _is_binary_operator(function) and is_selector(arg_2): + result_2 = function(arg_2, arg_1) + assert is_selector(result_2) + # `__sub__` is allowed, but `__rsub__` is not ... + elif function is not operator.sub: + result_2 = function(arg_2, arg_1) + assert_not_selector(result_2) + # ... unless both are `Selector` + elif is_selector(arg_2): + result_2 = function(arg_2, arg_1) + assert is_selector(result_2) + else: + with pytest.raises(TypeError): + function(arg_2, arg_1) + + +@pytest.mark.parametrize( + "selector", + [ + (ncs.string() | ncs.numeric()), + (ncs.numeric() | ncs.string()), + ~(~ncs.numeric() & ~ncs.string()), + ~(~ncs.string() & ~ncs.numeric()), + (ncs.by_dtype(nw.Int16) ^ ncs.matches(r"b|e|q")) - ncs.matches("^e"), + ], +) +def test_selector_result_order(schema_non_nested: nw.Schema, selector: Selector) -> None: + df = Frame(schema_non_nested) + df.assert_selects(selector, "abc", "bbb", "cde", "def", "qqR") + + +def test_selector_list(schema_nested_1: nw.Schema) -> None: + df = Frame(schema_nested_1) + + # inner None + df.assert_selects(ncs.list(), "b", "c", "e") + # Inner All (as a DTypeSelector) + df.assert_selects(ncs.list(ncs.all()), "b", "c", "e") + # inner DTypeSelector + df.assert_selects(ncs.list(ncs.integer()), "b", "c") + df.assert_selects(ncs.list(inner=ncs.string()), "e") + # inner BinarySelector + df.assert_selects( + ncs.list(ncs.by_dtype(nw.Int32) | ncs.by_dtype(nw.UInt32)), "b", "c" + ) + # inner InvertSelector + df.assert_selects(ncs.list(~ncs.all())) + + +def test_selector_array(schema_nested_2: nw.Schema) -> None: + df = Frame(schema_nested_2) + df.assert_selects(ncs.array(), "b", "c", "d", "f") + df.assert_selects(ncs.array(ncs.all()), "b", "c", "d", "f") + df.assert_selects(ncs.array(size=4), "b", "c", "f") + df.assert_selects(ncs.array(inner=ncs.integer()), "b", "c", "d") + df.assert_selects(ncs.array(inner=ncs.string()), "f") + + +def test_selector_non_dtype_inside_dtype(schema_nested_2: nw.Schema) -> None: + df = Frame(schema_nested_2) + + with pytest.raises( + TypeError, match=r"expected datatype based expression got.+by_name\(" + ): + df.project(ncs.list(inner=ncs.by_name("???"))) + + with pytest.raises( + TypeError, match=r"expected datatype based expression got.+by_name\(" + ): + df.project(ncs.array(inner=ncs.by_name("???"))) + + +def test_selector_enum() -> None: + df = Frame.from_mapping( + { + "a": nw.Int32(), + "b": nw.UInt32(), + "c": nw_v1.Enum(), + "d": nw.Categorical(), + "e": nw.String(), + "f": nw.Enum(["a", "b"]), + } + ) + df.assert_selects(ncs.enum(), "c", "f") + df.assert_selects(~ncs.enum(), "a", "b", "d", "e") + + +def test_selector_struct() -> None: + df = Frame.from_mapping( + { + "a": nw.Int32(), + "b": nw.Array(nw.Int32, shape=(4,)), + "c": nw.Struct({}), + "d": 
+
+
+def test_selector_list(schema_nested_1: nw.Schema) -> None:
+    df = Frame(schema_nested_1)
+
+    # inner None
+    df.assert_selects(ncs.list(), "b", "c", "e")
+    # inner All (as a DTypeSelector)
+    df.assert_selects(ncs.list(ncs.all()), "b", "c", "e")
+    # inner DTypeSelector
+    df.assert_selects(ncs.list(ncs.integer()), "b", "c")
+    df.assert_selects(ncs.list(inner=ncs.string()), "e")
+    # inner BinarySelector
+    df.assert_selects(
+        ncs.list(ncs.by_dtype(nw.Int32) | ncs.by_dtype(nw.UInt32)), "b", "c"
+    )
+    # inner InvertSelector
+    df.assert_selects(ncs.list(~ncs.all()))
+
+
+def test_selector_array(schema_nested_2: nw.Schema) -> None:
+    df = Frame(schema_nested_2)
+    df.assert_selects(ncs.array(), "b", "c", "d", "f")
+    df.assert_selects(ncs.array(ncs.all()), "b", "c", "d", "f")
+    df.assert_selects(ncs.array(size=4), "b", "c", "f")
+    df.assert_selects(ncs.array(inner=ncs.integer()), "b", "c", "d")
+    df.assert_selects(ncs.array(inner=ncs.string()), "f")
+
+
+def test_selector_non_dtype_inside_dtype(schema_nested_2: nw.Schema) -> None:
+    df = Frame(schema_nested_2)
+
+    with pytest.raises(
+        TypeError, match=r"expected datatype based expression got.+by_name\("
+    ):
+        df.project(ncs.list(inner=ncs.by_name("???")))
+
+    with pytest.raises(
+        TypeError, match=r"expected datatype based expression got.+by_name\("
+    ):
+        df.project(ncs.array(inner=ncs.by_name("???")))
+
+
+def test_selector_enum() -> None:
+    df = Frame.from_mapping(
+        {
+            "a": nw.Int32(),
+            "b": nw.UInt32(),
+            "c": nw_v1.Enum(),
+            "d": nw.Categorical(),
+            "e": nw.String(),
+            "f": nw.Enum(["a", "b"]),
+        }
+    )
+    df.assert_selects(ncs.enum(), "c", "f")
+    df.assert_selects(~ncs.enum(), "a", "b", "d", "e")
+
+
+def test_selector_struct() -> None:
+    df = Frame.from_mapping(
+        {
+            "a": nw.Int32(),
+            "b": nw.Array(nw.Int32, shape=(4,)),
+            "c": nw.Struct({}),
+            "d": nw.Array(nw.UInt32, shape=(4,)),
+            "e": nw.Struct({"x": nw.Int32, "y": nw.String}),
+            "f": nw.List(nw.Int32),
+            "g": nw.Array(nw.String, shape=(4,)),
+            "h": nw.Struct({"x": nw.Int32}),
+        }
+    )
+    df.assert_selects(ncs.struct(), "c", "e", "h")
+    df.assert_selects(~ncs.struct(), "a", "b", "d", "f", "g")
+
+
+def test_selector_matches_22816() -> None:
+    df = Frame.from_names("ham", "hamburger", "foo", "bar")
+    df.assert_selects(ncs.matches(r"^ham.*$"), "ham", "hamburger")
+    df.assert_selects(ncs.matches(r".*burger"), "hamburger")
+
+
+def test_selector_by_name_order_19384() -> None:
+    df = Frame.from_names("a", "b")
+    df.assert_selects(ncs.by_name("b", "a"), "b", "a")
+    df.assert_selects(ncs.by_name("b", "a", require_all=False), "b", "a")
+
+
+def test_selector_datetime_23767() -> None:
+    df = Frame.from_mapping(
+        {"a": nw.Datetime(), "b": nw.Datetime(time_zone=timezone.utc)}
+    )
+    df.assert_selects(ncs.datetime("us", time_zone=None), "a")
+    df.assert_selects(ncs.datetime("us", time_zone=["UTC"]), "b")
+    df.assert_selects(ncs.datetime("us", time_zone=[None, "UTC"]), "a", "b")
+
+
+def test_name_suffix_complex_selector(schema_mixed: nw.Schema) -> None:
+    df = Frame(schema_mixed)
+    selector = (
+        ncs.all() - (ncs.categorical() | ncs.by_name("a", "b") | ncs.matches("[fqohim]"))
+        ^ ncs.by_name("u", "a", "b", "d", "e", "f", "g")
+    ).name.suffix("_after")
+    selected_names = "a", "b", "c", "f", "j", "k", "l", "n", "r", "s"
+    expecteds = (named_ir(f"{name}_after", nwp.col(name)) for name in selected_names)
+    actuals = df.project(selector)
+
+    for actual, expected in zip_strict(actuals, expecteds):
+        assert_expr_ir_equal(actual, expected)
+
+
+def test_name_map_chain_21164() -> None:
+    # https://github.com/pola-rs/polars/blob/5b90db75911c70010d0c0a6941046e6144af88d4/py-polars/tests/unit/operations/namespaces/test_name.py#L110-L115
+    df = Frame.from_names("MyCol")
+    aliased = nwp.col("MyCol").alias("mycol_suffix")
+    rename_chain = ncs.all().name.to_lowercase().name.suffix("_suffix")
+    df.assert_selects(aliased, "mycol_suffix")
+    df.assert_selects(rename_chain, "mycol_suffix")
+
+
+def test_when_then_keep_map_13858() -> None:
+    # https://github.com/pola-rs/polars/blob/aaa11d6af7383a5f9b62f432e14cc2d4af6d8548/py-polars/tests/unit/operations/namespaces/test_name.py#L118-L138
+    # https://github.com/pola-rs/polars/issues/13858
+    df = Frame.from_names("a", "b")
+    aliased = nwp.int_range(3).alias("b_other")
+    when_keep_chain = (
+        nwp.when(nwp.lit(True))
+        .then(nwp.int_range(nwp.len()))
+        .otherwise(1 + nwp.col("b"))
+        .name.keep()
+        .name.suffix("_other")
+    )
+    df.assert_selects(aliased, "b_other")
+    df.assert_selects(when_keep_chain, "b_other")
diff --git a/tests/plan/utils.py b/tests/plan/utils.py
index 8f857b694c..23bfcff219 100644
--- a/tests/plan/utils.py
+++ b/tests/plan/utils.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
+import re
 from typing import TYPE_CHECKING, Any
 
 import pytest
 
+import narwhals as nw
 from narwhals import _plan as nwp
-from narwhals._plan import expressions as ir
+from narwhals._plan import Expr, Selector, _expansion, _parse, expressions as ir
+from narwhals._utils import qualified_type_name
 from tests.utils import assert_equal_data as _assert_equal_data
 
 pytest.importorskip("pyarrow")
@@ -13,9 +16,132 @@
 import pyarrow as pa
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable, Mapping
+    import sys
+    from collections.abc import Iterable, Mapping, Sequence
 
-    from typing_extensions import LiteralString
+    from typing_extensions import LiteralString, TypeAlias
+
+    from narwhals._plan.typing import IntoExpr, OneOrIterable, Seq
+    from narwhals.typing import IntoSchema
+
+    if sys.version_info >= (3, 11):
+        _Flags: TypeAlias = "int | re.RegexFlag"
+    else:
+        _Flags: TypeAlias = int
+
+
+def first(*names: str | Sequence[str]) -> nwp.Expr:
+    return nwp.col(*names).first()
+
+
+def last(*names: str | Sequence[str]) -> nwp.Expr:
+    return nwp.col(*names).last()
+
+
+class Frame:
+    """Schema-only `{Expr,Selector}` projection testing tool.
+
+    Arguments:
+        schema: A Narwhals Schema.
+
+    Examples:
+        >>> import narwhals as nw
+        >>> import narwhals._plan.selectors as ncs
+        >>> df = Frame.from_mapping(
+        ...     {
+        ...         "abc": nw.UInt16(),
+        ...         "bbb": nw.UInt32(),
+        ...         "cde": nw.Float64(),
+        ...         "def": nw.Float32(),
+        ...         "eee": nw.Boolean(),
+        ...     }
+        ... )
+
+        Determine the column names that expression input would select
+
+        >>> df.project_names(ncs.numeric() - ncs.by_index(1, 2))
+        ('abc', 'def')
+
+        Assert that an expression selects names in a given order
+
+        >>> df.assert_selects(ncs.by_name("eee", "abc"), "eee", "abc")
+
+        A helpful error is raised if something goes wrong
+
+        >>> df.assert_selects(ncs.duration(), "eee", "abc")
+        Traceback (most recent call last):
+        AssertionError: Projected column names do not match expected names:
+        result : ()
+        expected: ('eee', 'abc')
+    """
+
+    def __init__(self, schema: nw.Schema) -> None:
+        self.schema = schema
+        self.columns = tuple(schema.names())
+
+    @staticmethod
+    def from_mapping(mapping: IntoSchema) -> Frame:
+        """Construct from any inputs accepted by `nw.Schema`."""
+        return Frame(nw.Schema(mapping))
+
+    @staticmethod
+    def from_names(*column_names: str) -> Frame:
+        """Construct with every column typed as `nw.Int64()`."""
+        return Frame(nw.Schema((name, nw.Int64()) for name in column_names))
+
+    @property
+    def width(self) -> int:
+        """Get the number of columns in the schema."""
+        return len(self.columns)
+
+    def project(
+        self, exprs: OneOrIterable[IntoExpr], *more_exprs: IntoExpr
+    ) -> Seq[ir.NamedIR]:
+        """Parse and expand expressions into named representations.
+
+        Arguments:
+            exprs: Column(s) to select. Accepts expression input. Strings are parsed as column names,
+                other non-expression inputs are parsed as literals.
+            *more_exprs: Column(s) to select, specified as positional arguments.
+
+        Note:
+            `NamedIR` is the form of expression passed to the compliant-level.
+
+        Examples:
+            >>> import datetime as dt
+            >>> import narwhals._plan.selectors as ncs
+            >>> df = Frame.from_names("a", "b", "c", "d", "idx1", "idx2")
+            >>> expr_1 = (
+            ...     ncs.by_name("a", "d")
+            ...     .first()
+            ...     .over(ncs.by_index(range(1, 4)), order_by=ncs.matches(r"idx"))
+            ... )
+            >>> expr_2 = (ncs.by_name("a") | ncs.by_index(2)).abs().name.suffix("_abs")
+            >>> expr_3 = dt.date(2000, 1, 1)
+
+            >>> df.project(expr_1, expr_2, expr_3)  # doctest: +NORMALIZE_WHITESPACE
+            (a=col('a').first().over(partition_by=[col('b'), col('c'), col('d')], order_by=[col('idx1'), col('idx2')]),
+            d=col('d').first().over(partition_by=[col('b'), col('c'), col('d')], order_by=[col('idx1'), col('idx2')]),
+            a_abs=col('a').abs(),
+            c_abs=col('c').abs(),
+            literal=lit(date: 2000-01-01))
+        """
+        expr_irs = _parse.parse_into_seq_of_expr_ir(exprs, *more_exprs)
+        named_irs, _ = _expansion.prepare_projection(expr_irs, schema=self.schema)
+        return named_irs
+
+    def project_names(self, *exprs: IntoExpr) -> Seq[str]:
+        named_irs = self.project(*exprs)
+        return tuple(e.name for e in named_irs)
+
+    def assert_selects(self, selector: Selector | Expr, *column_names: str) -> None:
+        result = self.project_names(selector)
+        expected = column_names
+        assert result == expected, (
+            f"Projected column names do not match expected names:\n"
+            f"result : {result!r}\n"
+            f"expected: {expected!r}"
+        )
 
 
 def _unwrap_ir(obj: nwp.Expr | ir.ExprIR | ir.NamedIR) -> ir.ExprIR:
@@ -57,6 +183,25 @@ def assert_expr_ir_equal(
     assert lhs == rhs, f"\nlhs:\n {lhs!r}\n\nrhs:\n {rhs!r}"
 
 
+def assert_not_selector(actual: Expr | Selector, /) -> None:
+    """Assert that `actual` was converted into an `Expr`."""
+    assert isinstance(actual, Expr), (
+        f"Didn't expect you to pass a {qualified_type_name(actual)!r} here, got: {actual!r}"
+    )
+    assert not isinstance(actual, Selector), (
+        f"This operation should have returned `Expr`, but got {qualified_type_name(actual)!r}\n{actual!r}"
+    )
+
+
+def is_expr_ir_equal(actual: Expr | ir.ExprIR, expected: Expr | ir.ExprIR, /) -> bool:
+    """Return True if `actual` is equivalent to `expected`.
+
+    Note:
+        Prefer `assert_expr_ir_equal` unless you need a `bool` for branching.
+    """
+    return _unwrap_ir(actual) == _unwrap_ir(expected)
+
+
 def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]:
     """Helper constructor for test compare."""
     return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name)
@@ -74,3 +219,13 @@ def assert_equal_data(
     result: nwp.DataFrame[Any, Any], expected: Mapping[str, Any]
 ) -> None:
     _assert_equal_data(result.to_dict(as_series=False), expected)
+
+
+def re_compile(
+    pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE
+) -> re.Pattern[str]:
+    """Compile a regular expression pattern, returning a Pattern object.
+
+    Like `re.compile`, but `flags` defaults to `re.DOTALL | re.IGNORECASE`.
+    """
+    return re.compile(pattern, flags)
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index a3049eabd2..c5a8c135ab 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -8,6 +8,7 @@
 import sys
 import tempfile
 from pathlib import Path
+from subprocess import CompletedProcess
 
 
 def extract_docstring_examples(files: list[str]) -> list[tuple[Path, str, str]]:
@@ -48,18 +49,23 @@ def create_temp_files(examples: list[tuple[Path, str, str]]) -> list[tuple[Path, str]]:
     return temp_files
 
 
-def run_ruff_on_temp_files(temp_files: list[tuple[Path, str]]) -> list[str]:
+def run_ruff_on_temp_files(
+    temp_files: list[tuple[Path, str]],
+) -> CompletedProcess[str] | None:
     """Run ruff on all temporary files and collect error messages."""
+    from ruff.__main__ import find_ruff_bin
+
     temp_file_paths = [temp_file[0] for temp_file in temp_files]
+    ruff_bin = find_ruff_bin()
     result = subprocess.run(  # noqa: S603
-        [  # noqa: S607
-            "python",
-            "-m",
-            "ruff",
+        [
+            ruff_bin,
            "check",
            "--select=F",
-            "--ignore=F811",
+            # > (F821) Undefined name
+            # Doctest examples are not standalone modules, so this check misfires
+            "--ignore=F811,F821",
             *temp_file_paths,
         ],
         capture_output=True,
@@ -68,21 +74,22 @@
     )
 
     if result.returncode == 0:
-        return []  # No issues found
-    return result.stdout.splitlines()  # Return ruff errors as a list of lines
+        return None
+    return result
 
 
-def report_errors(errors: list[str], temp_files: list[tuple[Path, str]]) -> None:
+def report_errors(
+    completed: CompletedProcess[str] | None, temp_files: list[tuple[Path, str]]
+) -> None:
     """Map errors back to original examples and report them."""
-    if not errors:
+    if completed is None:
         return
-    print("❌ Ruff issues found in examples:\n")
-    for line in errors:
-        for temp_file, original_context in temp_files:
-            if str(temp_file) in line:
-                print(f"{original_context}{line.replace(str(temp_file), '')}")
-                break
+    print("Ruff issues found in examples:\n")
+    stdout = completed.stdout
+    for temp_file, original_context in temp_files:
+        stdout = stdout.replace(str(temp_file), original_context)
+    print(stdout)
 
 
 def cleanup_temp_files(temp_files: list[tuple[Path, str]]) -> None: