Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions narwhals/_plan/arrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from narwhals._plan.common import ensure_list_str, temp
from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq
from narwhals._utils import check_column_names_are_unique
from narwhals.typing import JoinStrategy, SingleColSelector
from narwhals.typing import AsofJoinStrategy, JoinStrategy, SingleColSelector

if TYPE_CHECKING:
from collections.abc import (
Expand All @@ -47,7 +47,12 @@
Aggregation as _Aggregation,
)
from narwhals._plan.arrow.group_by import AggSpec
from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny
from narwhals._plan.arrow.typing import (
ArrowAny,
ChunkedArrayAny,
JoinTypeSubset,
ScalarAny,
)
from narwhals._plan.typing import OneOrIterable, Seq
from narwhals.typing import NonNestedLiteral

Expand Down Expand Up @@ -278,6 +283,53 @@ def _hashjoin(
return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)])


def _join_asof_suffix_collisions(
left: pa.Table, right: pa.Table, right_on: str, right_by: Sequence[str], suffix: str
) -> pa.Table:
"""Adapted from [upstream] to avoid raising early.

[upstream]: https://github.com/apache/arrow/blob/9b03118e834dfdaa0cf9e03595477b499252a9cb/python/pyarrow/acero.py#L306-L316
"""
right_names = right.schema.names
allowed = {right_on, *right_by}
if collisions := set(right_names).difference(allowed).intersection(left.schema.names):
renamed = [f"{nm}{suffix}" if nm in collisions else nm for nm in right_names]
return right.rename_columns(renamed)
return right


def _join_asof_strategy_to_tolerance(
left_on: ChunkedArrayAny, right_on: ChunkedArrayAny, /, strategy: AsofJoinStrategy
) -> int:
"""Calculate the **required** `tolerance` argument, from `*on` values and strategy.

For both `polars` and `pandas` this is optional and in `narwhals` it isn't supported (yet).

So we need to get the lowest/highest value for a match and use that for a similar default.

`"backward"`:

tolerance <= right_on - left_on <= 0

`"forward"`:

0 <= right_on - left_on <= tolerance

Note:
`tolerance` is interpreted in the same units as the `on` keys.
"""
import narwhals._plan.arrow.functions as fn

if strategy == "nearest":
msg = "Only 'backward' and 'forward' strategies are currently supported for `pyarrow`"
raise NotImplementedError(msg)
lower = fn.min_horizontal(fn.min_(left_on), fn.min_(right_on))
upper = fn.max_horizontal(fn.max_(left_on), fn.max_(right_on))
scalar = fn.sub(lower, upper) if strategy == "backward" else fn.sub(upper, lower)
tolerance: int = fn.cast(scalar, fn.I64).as_py()
return tolerance


def declare(*declarations: Decl) -> Decl:
"""Compose one or more `Declaration` nodes for execution as a pipeline."""
if len(declarations) == 1:
Expand Down Expand Up @@ -439,6 +491,31 @@ def join_cross_tables(
return collect(decl, ensure_unique_column_names=True).remove_column(0)


def join_asof_tables(
left: pa.Table,
right: pa.Table,
left_on: str,
right_on: str,
*,
left_by: Sequence[str] = (),
right_by: Sequence[str] = (),
strategy: AsofJoinStrategy = "backward",
suffix: str = "_right",
) -> pa.Table:
"""Perform an inexact join between two tables, using the nearest key."""
right = _join_asof_suffix_collisions(left, right, right_on, right_by, suffix=suffix)
tolerance = _join_asof_strategy_to_tolerance(
left.column(left_on), right.column(right_on), strategy
)
lb: list[Any] = [] if not left_by else list(left_by)
rb: list[Any] = [] if not right_by else list(right_by)
join_opts = pac.AsofJoinNodeOptions(
left_on=left_on, right_on=right_on, left_by=lb, right_by=rb, tolerance=tolerance
)
inputs = [table_source(left), table_source(right)]
return Decl("asofjoin", join_opts, inputs).to_table()


def _add_column_table(
native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny
) -> pa.Table:
Expand Down
4 changes: 4 additions & 0 deletions narwhals/_plan/arrow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class ArrowFrameSeries(Generic[NativeT]):
_native: NativeT
_version: Version

# NOTE: Aliases to integrate with `@requires.backend_version`
_backend_version = compat.BACKEND_VERSION
_implementation = implementation

@property
def native(self) -> NativeT:
return self._native
Expand Down
29 changes: 27 additions & 2 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from narwhals._plan.compliant.typing import LazyFrameAny, namespace
from narwhals._plan.exceptions import shape_error
from narwhals._plan.expressions import NamedIR, named_ir
from narwhals._utils import Version, generate_repr
from narwhals._utils import Version, generate_repr, requires
from narwhals.schema import Schema

if TYPE_CHECKING:
Expand All @@ -37,7 +37,7 @@
from narwhals._plan.typing import NonCrossJoinStrategy
from narwhals._typing import _LazyAllowedImpl
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy
from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy

Incomplete: TypeAlias = Any

Expand Down Expand Up @@ -308,6 +308,31 @@ def join_inner(self, other: Self, on: list[str], /) -> Self:
"""Less flexible, but more direct equivalent to join(how="inner", left_on=...)`."""
return self._with_native(acero.join_inner_tables(self.native, other.native, on))

@requires.backend_version((16,))
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
left_by: Sequence[str] = (),
right_by: Sequence[str] = (),
strategy: AsofJoinStrategy = "backward",
suffix: str = "_right",
) -> Self:
return self._with_native(
acero.join_asof_tables(
self.native,
other.native,
left_on,
right_on,
left_by=left_by,
right_by=right_by,
strategy=strategy,
suffix=suffix,
)
)

def _filter(self, predicate: Predicate | acero.Expr) -> Self:
mask: Incomplete = predicate
return self._with_native(self.native.filter(mask))
Expand Down
33 changes: 22 additions & 11 deletions narwhals/_plan/compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy
from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy

Incomplete: TypeAlias = Any

Expand All @@ -72,6 +72,27 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: ...
# Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around
def filter(self, predicate: NamedIR, /) -> Self: ...
def join(
self,
other: Self,
*,
how: NonCrossJoinStrategy,
left_on: Sequence[str],
right_on: Sequence[str],
suffix: str = "_right",
) -> Self: ...
def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
left_by: Sequence[str] = (), # https://github.com/pola-rs/polars/issues/18496
right_by: Sequence[str] = (),
strategy: AsofJoinStrategy = "backward",
suffix: str = "_right",
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
@property
def schema(self) -> Mapping[str, DType]: ...
Expand Down Expand Up @@ -213,16 +234,6 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se

def filter(self, predicate: NamedIR, /) -> Self: ...
def iter_columns(self) -> Iterator[SeriesT]: ...
def join(
self,
other: Self,
*,
how: NonCrossJoinStrategy,
left_on: Sequence[str],
right_on: Sequence[str],
suffix: str = "_right",
) -> Self: ...
def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
def partition_by(
self, by: Sequence[str], *, include_key: bool = True
) -> list[Self]: ...
Expand Down
Loading
Loading