Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 32 additions & 114 deletions narwhals/_plan/expr_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@

from collections import deque
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING

from narwhals._plan import common
from narwhals._plan import common, meta
from narwhals._plan._immutable import Immutable
from narwhals._plan.common import ExprIR, NamedIR, SelectorIR, is_horizontal_reduction
from narwhals._plan.exceptions import (
Expand Down Expand Up @@ -74,11 +73,10 @@
from narwhals.exceptions import ComputeError, InvalidOperationError

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from collections.abc import Iterable, Iterator, Sequence

from typing_extensions import TypeAlias

from narwhals._plan.dummy import Expr
from narwhals._plan.typing import Seq
from narwhals.dtypes import DType

Expand Down Expand Up @@ -148,10 +146,6 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags:
has_exclude=has_exclude,
)

@classmethod
def from_expr(cls, expr: Expr, /) -> ExpansionFlags:
    """Build flags for `expr` by passing its `_ir` attribute to `from_ir`."""
    ir = expr._ir
    return cls.from_ir(ir)

def with_multiple_columns(self) -> ExpansionFlags:
    """Return the result of `common.replace` on `self` with `multiple_columns=True`."""
    updated = common.replace(self, multiple_columns=True)
    return updated

Expand Down Expand Up @@ -189,7 +183,7 @@ def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]:
def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames:
"""Raise an appropriate error if we can't materialize."""
output_names = _ensure_output_names_unique(exprs)
root_names = _root_names_unique(exprs)
root_names = meta.root_names_unique(exprs)
if not (set(schema.names).issuperset(root_names)):
raise column_not_found_error(root_names, schema)
return output_names
Expand All @@ -202,13 +196,6 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames:
return names


def _root_names_unique(exprs: Seq[ExprIR]) -> set[str]:
    """Return the set of names yielded by `_expr_to_leaf_column_names_iter` over `exprs`.

    Duplicates are collapsed because the names are gathered into a `set`.
    """
    # Imported locally, as in the original block.
    from narwhals._plan.meta import _expr_to_leaf_column_names_iter

    names: set[str] = set()
    for expr in exprs:
        names.update(_expr_to_leaf_column_names_iter(expr))
    return names


def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
def fn(child: ExprIR, /) -> ExprIR:
if is_horizontal_reduction(child):
Expand Down Expand Up @@ -239,18 +226,7 @@ def is_index_in_range(index: int, n_fields: int) -> bool:

def remove_alias(origin: ExprIR, /) -> ExprIR:
def fn(child: ExprIR, /) -> ExprIR:
if isinstance(child, Alias):
return child.expr
return child

return origin.map_ir(fn)


def remove_exclude(origin: ExprIR, /) -> ExprIR:
def fn(child: ExprIR, /) -> ExprIR:
if isinstance(child, Exclude):
return child.expr
return child
return child.expr if isinstance(child, Alias) else child

return origin.map_ir(fn)

Expand All @@ -263,20 +239,14 @@ def replace_with_column(
def fn(child: ExprIR, /) -> ExprIR:
if isinstance(child, tp):
return col(name)
if isinstance(child, Exclude):
return child.expr
return child
return child.expr if isinstance(child, Exclude) else child

return origin.map_ir(fn)


def replace_selector(ir: ExprIR, /, *, schema: FrozenSchema) -> ExprIR:
"""Fully diverging from `polars`, we'll see how that goes."""

def fn(child: ExprIR, /) -> ExprIR:
if isinstance(child, SelectorIR):
return expand_selector(child, schema=schema)
return child
return expand_selector(child, schema) if isinstance(child, SelectorIR) else child

return ir.map_ir(fn)

Expand All @@ -293,7 +263,7 @@ def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) ->


@lru_cache(maxsize=100)
def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns:
def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns:
"""Expand `selector` into `Columns`, within the context of `schema`."""
matches = selector_matches_column
return cols(*(k for k, v in schema.items() if matches(selector, k, v)))
Expand All @@ -313,64 +283,46 @@ def rewrite_projections(
if flags.has_selector:
expanded = replace_selector(expanded, schema=schema)
flags = flags.with_multiple_columns()
result.extend(
replace_and_add_to_results(
expanded, keys=keys, col_names=schema.names, flags=flags
)
)
result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags))
return tuple(result)


def replace_and_add_to_results(
def iter_replace(
origin: ExprIR,
/,
keys: GroupByKeys,
*,
col_names: FrozenColumns,
flags: ExpansionFlags,
) -> Seq[ExprIR]:
result: deque[ExprIR] = deque()
) -> Iterator[ExprIR]:
if flags.has_nth:
origin = replace_nth(origin, col_names)
if flags.expands:
it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns)))
if e := next(it, None):
if isinstance(e, Columns):
exclude = prepare_excluded(
origin, keys=keys, has_exclude=flags.has_exclude
)
result.extend(expand_columns(origin, e, exclude=exclude))
if not _all_columns_match(origin, e):
msg = "expanding more than one `col` is not allowed"
raise ComputeError(msg)
names: Iterable[str] = e.names
else:
exclude = prepare_excluded(
origin, keys=keys, has_exclude=flags.has_exclude
)
result.extend(
expand_indices(origin, e, col_names=col_names, exclude=exclude)
)
names = _iter_index_names(e, col_names)
exclude = prepare_excluded(origin, keys, flags)
yield from expand_column_selection(origin, type(e), names, exclude)
elif flags.has_wildcard:
exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude)
result.extend(replace_wildcard(origin, col_names=col_names, exclude=exclude))
exclude = prepare_excluded(origin, keys, flags)
yield from expand_column_selection(origin, All, col_names, exclude)
else:
exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude)
expanded = rewrite_special_aliases(origin)
result.append(expanded)
return tuple(result)


def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]:
"""Yield all excluded names in `origin`."""
for e in origin.iter_left():
if isinstance(e, Exclude):
yield from e.names
yield rewrite_special_aliases(origin)


def prepare_excluded(
origin: ExprIR, /, keys: GroupByKeys, *, has_exclude: bool
origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, /
) -> Excluded:
"""Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555."""
exclude: set[str] = set()
if has_exclude:
exclude.update(_iter_exclude_names(origin))
if flags.has_exclude:
exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude)))
for group_by_key in keys:
if name := group_by_key.meta.output_name(raise_if_undetermined=False):
exclude.add(name)
Expand All @@ -382,52 +334,20 @@ def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool:
return all(it)


def expand_columns(
origin: ExprIR, /, columns: Columns, *, exclude: Excluded
) -> Seq[ExprIR]:
if not _all_columns_match(origin, columns):
msg = "expanding more than one `col` is not allowed"
raise ComputeError(msg)
result: deque[ExprIR] = deque()
for name in columns.names:
if name not in exclude:
expanded = replace_with_column(origin, Columns, name)
expanded = rewrite_special_aliases(expanded)
result.append(expanded)
return tuple(result)


def expand_indices(
origin: ExprIR,
/,
indices: IndexColumns,
*,
col_names: FrozenColumns,
exclude: Excluded,
) -> Seq[ExprIR]:
result: deque[ExprIR] = deque()
n_fields = len(col_names)
def _iter_index_names(indices: IndexColumns, names: FrozenColumns, /) -> Iterator[str]:
n_fields = len(names)
for index in indices.indices:
if not is_index_in_range(index, n_fields):
raise column_index_error(index, col_names)
name = col_names[index]
if name not in exclude:
expanded = replace_with_column(origin, IndexColumns, name)
expanded = rewrite_special_aliases(expanded)
result.append(expanded)
return tuple(result)
raise column_index_error(index, names)
yield names[index]


def replace_wildcard(
origin: ExprIR, /, *, col_names: FrozenColumns, exclude: Excluded
) -> Seq[ExprIR]:
result: deque[ExprIR] = deque()
for name in col_names:
def expand_column_selection(
origin: ExprIR, tp: type[_ColumnSelection], /, names: Iterable[str], exclude: Excluded
) -> Iterator[ExprIR]:
for name in names:
if name not in exclude:
expanded = replace_with_column(origin, All, name)
expanded = rewrite_special_aliases(expanded)
result.append(expanded)
return tuple(result)
yield rewrite_special_aliases(replace_with_column(origin, tp, name))


def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR:
Expand All @@ -438,8 +358,6 @@ def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR:
- Expanding all selections into `Column`
- Dealing with `FunctionExpr.input`
"""
from narwhals._plan import meta

if meta.has_expr_ir(origin, KeepName, RenameAlias):
if isinstance(origin, KeepName):
parent = origin.expr
Expand Down
7 changes: 6 additions & 1 deletion narwhals/_plan/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from __future__ import annotations

from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Literal, overload

from narwhals._plan.common import IRNamespace
from narwhals.exceptions import ComputeError
from narwhals.utils import Version

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Iterable, Iterator

from typing_extensions import TypeIs

Expand Down Expand Up @@ -121,6 +122,10 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError:
return ComputeError(msg)


def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]:
    """Return the set of names yielded by `_expr_to_leaf_column_names_iter` over `irs`.

    Duplicates are collapsed because the names are gathered into a `set`.
    """
    names: set[str] = set()
    for ir in irs:
        names.update(_expr_to_leaf_column_names_iter(ir))
    return names


@lru_cache(maxsize=32)
def _expr_output_name(ir: ExprIR) -> str | ComputeError:
from narwhals._plan import expr
Expand Down
Loading