Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
cb470b4
refactor: Use `temp.column_name(s)` some more
dangotbanned Oct 1, 2025
23e9d43
fix(typing): Resolve some cases for `flatten_hash_safe`
dangotbanned Oct 1, 2025
f77bb4c
feat(expr-ir): Impl `acero.sort_by`
dangotbanned Oct 2, 2025
36ddce0
test: Port over `is_first_distinct` tests
dangotbanned Oct 2, 2025
0e49f57
chore: Add `Compliant{Expr,Scalar}.is_{first,last}_distinct`
dangotbanned Oct 2, 2025
a5f192c
test: Update to cover `is_last_distinct` as well
dangotbanned Oct 2, 2025
6a1b08a
feat(DRAFT): Initial `is_first_distinct` impl
dangotbanned Oct 2, 2025
1c026bf
test: Port over more cases
dangotbanned Oct 3, 2025
e7e8a04
refactor: Generalize `is_first_distinct` impl
dangotbanned Oct 3, 2025
2d46521
feat: Add `is_last_distinct`
dangotbanned Oct 3, 2025
cfb775d
refactor: Make both `is_*_distinct` methods, aliases
dangotbanned Oct 3, 2025
9db603b
feat: (Properly) add `get_column`, `to_series`
dangotbanned Oct 3, 2025
f8255d3
chore: Add `pc.is_in` wrapper
dangotbanned Oct 3, 2025
6fe2a0a
docs: Add detail to `FunctionFlags.LENGTH_PRESERVING`
dangotbanned Oct 3, 2025
938befb
test: More test porting
dangotbanned Oct 3, 2025
516f4a6
typo
dangotbanned Oct 3, 2025
ead4e62
feat(DRAFT): Some progress on `hashjoin` port
dangotbanned Oct 4, 2025
273bdcc
fix: Correctly pass down join keys
dangotbanned Oct 5, 2025
ce37617
test: Port over inner, left & clean up
dangotbanned Oct 5, 2025
18ef26a
test: Add `test_suffix`
dangotbanned Oct 5, 2025
94baf1e
test: Add `how="cross"` tests
dangotbanned Oct 5, 2025
733b45a
test: Add `how={"anti","semi"}` tests
dangotbanned Oct 5, 2025
ce321e0
test: replace `"antananarivo"`->`"a"`, `"bob"`->`"b"`
dangotbanned Oct 5, 2025
cc0d379
test: Port the other duplicate test
dangotbanned Oct 5, 2025
dd40e3a
test: Make all the xfails more visible
dangotbanned Oct 5, 2025
d1a1785
feat(DRAFT): Initial acero cross-join impl
dangotbanned Oct 5, 2025
77e55b3
refactor: Only expose `acero.join_tables`
dangotbanned Oct 5, 2025
8f7d2f3
chore: Start factoring-out `Table` dependency
dangotbanned Oct 5, 2025
b0c2a4d
Merge branch 'oh-nodes' into expr-ir/acero-order-by
dangotbanned Oct 6, 2025
d42f5de
refactor(typing): Use `IntoExprColumn` some more
dangotbanned Oct 6, 2025
b8a58c1
refactor: Split up `_parse_sort_by`
dangotbanned Oct 6, 2025
05c63fd
Make a start on `DataFrame.filter`
dangotbanned Oct 6, 2025
025213d
fill out slightly more `filter`
dangotbanned Oct 6, 2025
3e94449
get typing working again (kinda)
dangotbanned Oct 6, 2025
a611bc9
feat(DRAFT): Support `filter(list[bool])`
dangotbanned Oct 6, 2025
d514ad0
feat: Support single `Series` as well
dangotbanned Oct 6, 2025
d452920
test: Use `parametrize`
dangotbanned Oct 6, 2025
4c7c23d
feat: Add predicate expansion
dangotbanned Oct 6, 2025
2ebca30
feat(expr-ir): Full `DataFrame.filter` support
dangotbanned Oct 6, 2025
1b66786
test: Merge the anti/semi tests
dangotbanned Oct 6, 2025
fd38911
test: parametrize exception messages
dangotbanned Oct 6, 2025
3537cac
test: relax more error messages
dangotbanned Oct 6, 2025
b5ef86b
typo
dangotbanned Oct 7, 2025
8433b2d
test: Add `test_filter_mask_mixed`
dangotbanned Oct 7, 2025
7668abb
fix: Raise on duplicate column names
dangotbanned Oct 7, 2025
3ca43d1
cov
dangotbanned Oct 7, 2025
0f06479
perf: Avoid multiple collections during cross join
dangotbanned Oct 7, 2025
7e9ee74
test: Stop repeating the same data so many times
dangotbanned Oct 7, 2025
1523dbb
test: Add some cases from polars
dangotbanned Oct 8, 2025
a479f32
fix: typing mypy
dangotbanned Oct 8, 2025
8e840e0
feat(expr-ir): Full-er `DataFrame.filter` support
dangotbanned Oct 8, 2025
af26916
refactor: Simplify the `NonCrossJoinStrategy` split
dangotbanned Oct 8, 2025
6aaf75d
test: Convert raising test into a conformance test
dangotbanned Oct 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions narwhals/_plan/arrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import pyarrow.compute as pc # ignore-banned-import
from pyarrow.acero import Declaration as Decl

from narwhals._plan.common import flatten_hash_safe
from narwhals._plan.options import SortMultipleOptions
from narwhals._plan.typing import OneOrSeq
from narwhals.typing import SingleColSelector

Expand Down Expand Up @@ -189,10 +191,15 @@ def _order_by(
return Decl("order_by", pac.OrderByNodeOptions(keys, null_placement=null_placement))


# TODO @dangotbanned: Utilize `SortMultipleOptions.to_arrow_acero`
def sort_by(*args: Any, **kwds: Any) -> Decl:
msg = "Should convert from polars args -> use `_order_by"
raise NotImplementedError(msg)
# NOTE(review): the annotation `OneOrIterable` is used below, but the visible
# import block only adds `OneOrSeq` from `narwhals._plan.typing` -- confirm
# `OneOrIterable` is actually imported in the full file.
def sort_by(
    by: OneOrIterable[str],
    *more_by: str,
    descending: OneOrIterable[bool] = False,
    nulls_last: bool = False,
) -> Decl:
    """Build an acero ``order_by`` ``Declaration`` from polars-style sort arguments.

    Parameters
    ----------
    by
        One or more key column names (a single name, or an iterable of names).
    *more_by
        Additional key names given positionally; flattened together with ``by``.
    descending
        Per-key (or single, presumably broadcast) sort direction -- parsed by
        ``SortMultipleOptions.parse``.
    nulls_last
        Whether nulls sort after non-null values.

    Returns
    -------
    Decl
        The acero declaration produced by ``SortMultipleOptions.to_arrow_acero``.
    """
    # `flatten_hash_safe` unwraps any nesting so a str, an iterable of str, and
    # extra positional names all collapse into one flat tuple of key names.
    return SortMultipleOptions.parse(
        descending=descending, nulls_last=nulls_last
    ).to_arrow_acero(tuple(flatten_hash_safe((by, more_by))))
Comment on lines +261 to +269
Copy link
Member Author

@dangotbanned dangotbanned Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of feat(expr-ir): Impl acero.sort_by, I still need to make use of this in a plan.

A good candidate might be in either/both of

over(order_by=...)

def over_ordered(
    self, node: ir.OrderedWindowExpr, frame: Frame, name: str
) -> Self | Scalar:
    """Evaluate ``expr.over(order_by=...)`` by sorting, evaluating, then un-sorting.

    Only the ``order_by`` part of a window is handled here; windows that also
    have ``partition_by`` are not implemented yet and raise.
    """
    if node.partition_by:
        msg = f"Need to implement `group_by`, `join` for:\n{node!r}"
        raise NotImplementedError(msg)
    # NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
    sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
    options = node.sort_options.to_multiple(len(node.order_by))
    # A temporary row-index column records the original row order so the
    # result can be restored to it after evaluating in sorted order.
    idx_name = temp.column_name(frame)
    sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
    # Evaluate the inner expression against the sorted frame (index dropped
    # so it cannot leak into the expression's visible columns).
    evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
    if isinstance(evaluated, ArrowScalar):
        # NOTE: We're already sorted, defer broadcasting to the outer context
        # Wouldn't be suitable for partitions, but will be fine here
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/2ae42458cae91f4473e01270919815fcd7cb9667
        # - https://github.com/narwhals-dev/narwhals/pull/2528/commits/b8066c4c57d4b0b6c38d58a0f5de05eefc2cae70
        return self._with_native(evaluated.native, name)
    # Sorting the saved index yields the permutation that maps the sorted
    # result back to the frame's original row order.
    indices = pc.sort_indices(sorted_context.get_column(idx_name).native)
    height = len(sorted_context)
    result = evaluated.broadcast(height).native.take(indices)
    return self._with_native(result, name)

is_{first,last}_distinct

def is_first_distinct(self) -> Self:
    """Return a boolean mask that is ``True`` at the first occurrence of each distinct value.

    Implemented by tagging every row with its position, grouping by value to
    find each group's minimum position, then testing membership in that set.
    """
    import numpy as np  # ignore-banned-import

    positions = pa.array(np.arange(len(self)))
    # Temporary column name guaranteed not to collide with `self.name`.
    tmp = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    tagged = pa.Table.from_arrays([self.native], names=[self.name]).append_column(
        tmp, positions
    )
    # The smallest position within each value-group marks its first occurrence.
    first_positions = (
        tagged.group_by(self.name).aggregate([(tmp, "min")]).column(f"{tmp}_min")
    )
    return self._with_native(pc.is_in(positions, first_positions))
def is_last_distinct(self) -> Self:
    """Return a boolean mask that is ``True`` at the last occurrence of each distinct value.

    Mirrors ``is_first_distinct``: rows are tagged with their position and the
    maximum position per value-group identifies each value's final occurrence.
    """
    import numpy as np  # ignore-banned-import

    positions = pa.array(np.arange(len(self)))
    # Temporary column name guaranteed not to collide with `self.name`.
    tmp = generate_temporary_column_name(n_bytes=8, columns=[self.name])
    tagged = pa.Table.from_arrays([self.native], names=[self.name]).append_column(
        tmp, positions
    )
    # The largest position within each value-group marks its last occurrence.
    last_positions = (
        tagged.group_by(self.name).aggregate([(tmp, "max")]).column(f"{tmp}_max")
    )
    return self._with_native(pc.is_in(positions, last_positions))



def collect(*declarations: Decl, use_threads: bool = True) -> pa.Table:
Expand Down
17 changes: 5 additions & 12 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,13 @@
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries as Series
from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co
from narwhals._plan.common import temp
from narwhals._plan.compliant.column import ExprDispatch
from narwhals._plan.compliant.expr import EagerExpr
from narwhals._plan.compliant.scalar import EagerScalar
from narwhals._plan.compliant.typing import namespace
from narwhals._plan.expressions import NamedIR
from narwhals._utils import (
Implementation,
Version,
_StoresNative,
generate_temporary_column_name,
not_implemented,
)
from narwhals._utils import Implementation, Version, _StoresNative, not_implemented
from narwhals.exceptions import InvalidOperationError, ShapeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -231,10 +226,8 @@ def sort(self, node: ir.Sort, frame: Frame, name: str) -> Expr:

def sort_by(self, node: ir.SortBy, frame: Frame, name: str) -> Expr:
series = self._dispatch_expr(node.expr, frame, name)
by = (
self._dispatch_expr(e, frame, f"<TEMP>_{idx}")
for idx, e in enumerate(node.by)
)
it_names = temp.column_names(frame)
by = (self._dispatch_expr(e, frame, nm) for e, nm in zip(node.by, it_names))
df = namespace(self)._concat_horizontal((series, *by))
names = df.columns[1:]
indices = pc.sort_indices(df.native, options=node.options.to_arrow(names))
Expand Down Expand Up @@ -342,7 +335,7 @@ def over_ordered(
# NOTE: Converting `over(order_by=..., options=...)` into the right shape for `DataFrame.sort`
sort_by = tuple(NamedIR.from_ir(e) for e in node.order_by)
options = node.sort_options.to_multiple(len(node.order_by))
idx_name = generate_temporary_column_name(8, frame.columns)
idx_name = temp.column_name(frame)
sorted_context = frame.with_row_index(idx_name).sort(sort_by, options)
evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name)
if isinstance(evaluated, ArrowScalar):
Expand Down
16 changes: 15 additions & 1 deletion narwhals/_plan/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from typing_extensions import TypeIs

from narwhals._plan.compliant.series import CompliantSeries
from narwhals._plan.series import Series
from narwhals._plan.typing import (
DTypeT,
ExprIRT,
Expand Down Expand Up @@ -109,9 +111,21 @@ def into_dtype(dtype: DTypeT | type[NonNestedDTypeT], /) -> DTypeT | NonNestedDT
return dtype


# TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021)
# NOTE: See (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021)
# The issue is `T` possibly being `Iterable`
# Ignoring here still leaks the issue to the caller, where you need to annotate the base case
@overload
def flatten_hash_safe(iterable: Iterable[OneOrIterable[str]], /) -> Iterator[str]: ...
Comment on lines +115 to +119
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an improvement over the previous version, but far from ideal.

Still doesn't resolve this case, and I'm not entirely sure why yet

@classmethod
def align(
cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]]
) -> Iterator[SeriesT]:
exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs))
length = cls._length_required(exprs)
if length is None:
for e in exprs:
yield e.to_series()
else:
for e in exprs:
yield e.broadcast(length)

@overload
def flatten_hash_safe(
iterable: Iterable[OneOrIterable[Series]], /
) -> Iterator[Series]: ...
@overload
def flatten_hash_safe(
iterable: Iterable[OneOrIterable[CompliantSeries]], /
) -> Iterator[CompliantSeries]: ...
@overload
def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: ...
def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]:
"""Fully unwrap all levels of nesting.

Expand Down
3 changes: 1 addition & 2 deletions narwhals/_plan/expressions/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ class Exclude(_ColumnSelection, child=("expr",)):

@staticmethod
def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude:
flat: t.Iterator[str] = flatten_hash_safe(names)
return Exclude(expr=expr, names=tuple(flat))
return Exclude(expr=expr, names=tuple(flatten_hash_safe(names)))

def __repr__(self) -> str:
return f"{self.expr!r}.exclude({list(self.names)!r})"
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_plan/expressions/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from narwhals._utils import Version, _parse_time_unit_and_time_zone

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import timezone
from typing import TypeVar

Expand Down Expand Up @@ -127,7 +126,7 @@ def from_string(pattern: str, /) -> Matches:
@staticmethod
def from_names(*names: OneOrIterable[str]) -> Matches:
"""Implements `cs.by_name` to support `__r<op>__` with column selections."""
it: Iterator[str] = flatten_hash_safe(names)
it = flatten_hash_safe(names)
return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$")

def __repr__(self) -> str:
Expand Down
Loading