diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py index 97f5fb9842..d1c942191c 100644 --- a/narwhals/_plan/arrow/acero.py +++ b/narwhals/_plan/arrow/acero.py @@ -28,7 +28,7 @@ from narwhals._plan.common import ensure_list_str, temp from narwhals._plan.typing import NonCrossJoinStrategy, OneOrSeq from narwhals._utils import check_column_names_are_unique -from narwhals.typing import JoinStrategy, SingleColSelector +from narwhals.typing import AsofJoinStrategy, JoinStrategy, SingleColSelector if TYPE_CHECKING: from collections.abc import ( @@ -47,7 +47,12 @@ Aggregation as _Aggregation, ) from narwhals._plan.arrow.group_by import AggSpec - from narwhals._plan.arrow.typing import ArrowAny, JoinTypeSubset, ScalarAny + from narwhals._plan.arrow.typing import ( + ArrowAny, + ChunkedArrayAny, + JoinTypeSubset, + ScalarAny, + ) from narwhals._plan.typing import OneOrIterable, Seq from narwhals.typing import NonNestedLiteral @@ -278,6 +283,53 @@ def _hashjoin( return Decl("hashjoin", options, [_into_decl(left), _into_decl(right)]) +def _join_asof_suffix_collisions( + left: pa.Table, right: pa.Table, right_on: str, right_by: Sequence[str], suffix: str +) -> pa.Table: + """Adapted from [upstream] to avoid raising early. + + [upstream]: https://github.com/apache/arrow/blob/9b03118e834dfdaa0cf9e03595477b499252a9cb/python/pyarrow/acero.py#L306-L316 + """ + right_names = right.schema.names + allowed = {right_on, *right_by} + if collisions := set(right_names).difference(allowed).intersection(left.schema.names): + renamed = [f"{nm}{suffix}" if nm in collisions else nm for nm in right_names] + return right.rename_columns(renamed) + return right + + +def _join_asof_strategy_to_tolerance( + left_on: ChunkedArrayAny, right_on: ChunkedArrayAny, /, strategy: AsofJoinStrategy +) -> int: + """Calculate the **required** `tolerance` argument, from `*on` values and strategy. + + For both `polars` and `pandas` this is optional and in `narwhals` it isn't supported (yet). + + So we need to get the lowest/highest value for a match and use that for a similar default. + + `"backward"`: + + tolerance <= right_on - left_on <= 0 + + `"forward"`: + + 0 <= right_on - left_on <= tolerance + + Note: + `tolerance` is interpreted in the same units as the `on` keys. + """ + import narwhals._plan.arrow.functions as fn + + if strategy == "nearest": + msg = "Only 'backward' and 'forward' strategies are currently supported for `pyarrow`" + raise NotImplementedError(msg) + lower = fn.min_horizontal(fn.min_(left_on), fn.min_(right_on)) + upper = fn.max_horizontal(fn.max_(left_on), fn.max_(right_on)) + scalar = fn.sub(lower, upper) if strategy == "backward" else fn.sub(upper, lower) + tolerance: int = fn.cast(scalar, fn.I64).as_py() + return tolerance + + def declare(*declarations: Decl) -> Decl: """Compose one or more `Declaration` nodes for execution as a pipeline.""" if len(declarations) == 1: @@ -439,6 +491,31 @@ def join_cross_tables( return collect(decl, ensure_unique_column_names=True).remove_column(0) +def join_asof_tables( + left: pa.Table, + right: pa.Table, + left_on: str, + right_on: str, + *, + left_by: Sequence[str] = (), + right_by: Sequence[str] = (), + strategy: AsofJoinStrategy = "backward", + suffix: str = "_right", +) -> pa.Table: + """Perform an inexact join between two tables, using the nearest key.""" + right = _join_asof_suffix_collisions(left, right, right_on, right_by, suffix=suffix) + tolerance = _join_asof_strategy_to_tolerance( + left.column(left_on), right.column(right_on), strategy + ) + lb: list[Any] = [] if not left_by else list(left_by) + rb: list[Any] = [] if not right_by else list(right_by) + join_opts = pac.AsofJoinNodeOptions( + left_on=left_on, right_on=right_on, left_by=lb, right_by=rb, tolerance=tolerance + ) + inputs = [table_source(left), table_source(right)] + return Decl("asofjoin", join_opts, inputs).to_table() + + def _add_column_table( native: pa.Table, index: int, name: str, values: IntoExpr | ArrowAny ) -> pa.Table: diff --git a/narwhals/_plan/arrow/common.py b/narwhals/_plan/arrow/common.py index 22fd016574..5a6a47bd37 100644 --- a/narwhals/_plan/arrow/common.py +++ b/narwhals/_plan/arrow/common.py @@ -31,6 +31,10 @@ class ArrowFrameSeries(Generic[NativeT]): _native: NativeT _version: Version + # NOTE: Aliases to integrate with `@requires.backend_version` + _backend_version = compat.BACKEND_VERSION + _implementation = implementation + @property def native(self) -> NativeT: return self._native diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index 31e220a2a5..b80b59e039 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -20,7 +20,7 @@ from narwhals._plan.compliant.typing import LazyFrameAny, namespace from narwhals._plan.exceptions import shape_error from narwhals._plan.expressions import NamedIR, named_ir -from narwhals._utils import Version, generate_repr +from narwhals._utils import Version, generate_repr, requires from narwhals.schema import Schema if TYPE_CHECKING: @@ -37,7 +37,7 @@ from narwhals._plan.typing import NonCrossJoinStrategy from narwhals._typing import _LazyAllowedImpl from narwhals.dtypes import DType - from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy + from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy Incomplete: TypeAlias = Any @@ -308,6 +308,31 @@ def join_inner(self, other: Self, on: list[str], /) -> Self: """Less flexible, but more direct equivalent to join(how="inner", left_on=...)`.""" return self._with_native(acero.join_inner_tables(self.native, other.native, on)) + @requires.backend_version((16,)) + def join_asof( + self, + other: Self, + *, + left_on: str, + right_on: str, + left_by: Sequence[str] = (), + right_by: Sequence[str] = (), + strategy: AsofJoinStrategy = "backward", + suffix: str = "_right", + ) -> Self: + return self._with_native( + acero.join_asof_tables( + self.native, + other.native, + left_on, + right_on, + left_by=left_by, + right_by=right_by, + strategy=strategy, + suffix=suffix, + ) + ) + def _filter(self, predicate: Predicate | acero.Expr) -> Self: mask: Incomplete = predicate return self._with_native(self.native.filter(mask)) diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py index 7a935f0c8e..0bfbaf2517 100644 --- a/narwhals/_plan/compliant/dataframe.py +++ b/narwhals/_plan/compliant/dataframe.py @@ -45,7 +45,7 @@ from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl from narwhals._utils import Implementation, Version from narwhals.dtypes import DType - from narwhals.typing import IntoSchema, PivotAgg, UniqueKeepStrategy + from narwhals.typing import AsofJoinStrategy, IntoSchema, PivotAgg, UniqueKeepStrategy Incomplete: TypeAlias = Any @@ -72,6 +72,27 @@ def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... def explode(self, subset: Sequence[str], options: ExplodeOptions) -> Self: ... # Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around def filter(self, predicate: NamedIR, /) -> Self: ... + def join( + self, + other: Self, + *, + how: NonCrossJoinStrategy, + left_on: Sequence[str], + right_on: Sequence[str], + suffix: str = "_right", + ) -> Self: ... + def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ... + def join_asof( + self, + other: Self, + *, + left_on: str, + right_on: str, + left_by: Sequence[str] = (), # https://github.com/pola-rs/polars/issues/18496 + right_by: Sequence[str] = (), + strategy: AsofJoinStrategy = "backward", + suffix: str = "_right", + ) -> Self: ... def rename(self, mapping: Mapping[str, str]) -> Self: ... @property def schema(self) -> Mapping[str, DType]: ... @@ -213,16 +234,6 @@ def group_by_resolver(self, resolver: GroupByResolver, /) -> DataFrameGroupBy[Se def filter(self, predicate: NamedIR, /) -> Self: ... def iter_columns(self) -> Iterator[SeriesT]: ... - def join( - self, - other: Self, - *, - how: NonCrossJoinStrategy, - left_on: Sequence[str], - right_on: Sequence[str], - suffix: str = "_right", - ) -> Self: ... - def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ... def partition_by( self, by: Sequence[str], *, include_key: bool = True ) -> list[Self]: ... diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index 134e46ec1e..733a7b076b 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -32,6 +32,7 @@ from narwhals.exceptions import InvalidOperationError, ShapeError from narwhals.schema import Schema from narwhals.typing import ( + AsofJoinStrategy, EagerAllowed, FileSource, IntoBackend, @@ -94,6 +95,21 @@ def __repr__(self) -> str: def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: self._compliant = compliant + def _unwrap_compliant(self, other: Self | Any, /) -> Incomplete: + """Return the `CompliantFrame` that backs `other` if it matches self. + + - Rejects (`DataFrame`, `LazyFrame`) and (`LazyFrame`, `DataFrame`) + - Rejects mixed backends like (`DataFrame[pa.Table]`, `DataFrame[pd.DataFrame]`) + """ + if isinstance(other, type(self)): + compliant = other._compliant + if isinstance(compliant, type(self._compliant)): + return compliant + msg = f"Expected {qualified_type_name(self._compliant)!r}, got {qualified_type_name(compliant)!r}" + raise NotImplementedError(msg) + msg = f"Expected `other` to be a {qualified_type_name(self)!r}, got: {qualified_type_name(other)!r}" # pragma: no cover + raise TypeError(msg) # pragma: no cover + def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: return type(self)(compliant) @@ -191,6 +207,67 @@ def with_row_index( by_names = expand_selector_irs_names(by_selectors, schema=self, require_any=True) return self._with_compliant(self._compliant.with_row_index_by(name, by_names)) + def join( + self, + other: Incomplete, + on: str | Sequence[str] | None = None, + how: JoinStrategy = "inner", + *, + left_on: str | Sequence[str] | None = None, + right_on: str | Sequence[str] | None = None, + suffix: str = "_right", + ) -> Self: + left = self._compliant + right: CompliantFrame[Any, NativeFrameT_co] = self._unwrap_compliant(other) + how = _validate_join_strategy(how) + if how == "cross": + if left_on is not None or right_on is not None or on is not None: + msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" + raise ValueError(msg) + return self._with_compliant(left.join_cross(right, suffix=suffix)) + left_on, right_on = normalize_join_on(on, how, left_on, right_on) + return self._with_compliant( + left.join(right, how=how, left_on=left_on, right_on=right_on, suffix=suffix) + ) + + def join_asof( + self, + other: Incomplete, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | Sequence[str] | None = None, + by_right: str | Sequence[str] | None = None, + by: str | Sequence[str] | None = None, + strategy: AsofJoinStrategy = "backward", + suffix: str = "_right", + ) -> Self: + left = self._compliant + right: CompliantFrame[Any, NativeFrameT_co] = self._unwrap_compliant(other) + strategy = _validate_join_asof_strategy(strategy) + left_on_, right_on_ = normalize_join_asof_on(left_on, right_on, on) + if by_left or by_right or by: + left_by, right_by = normalize_join_asof_by(by_left, by_right, by) + result = left.join_asof( + right, + left_on=left_on_, + right_on=right_on_, + left_by=left_by, + right_by=right_by, + strategy=strategy, + suffix=suffix, + ) + else: + result = left.join_asof( + right, + left_on=left_on_, + right_on=right_on_, + strategy=strategy, + suffix=suffix, + ) + return self._with_compliant(result) + def explode( self, columns: OneOrIterable[ColumnNameOrSelector], @@ -432,16 +509,33 @@ def join( right_on: str | Sequence[str] | None = None, suffix: str = "_right", ) -> Self: - left, right = self._compliant, other._compliant - how = _validate_join_strategy(how) - if how == "cross": - if left_on is not None or right_on is not None or on is not None: - msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" - raise ValueError(msg) - return self._with_compliant(left.join_cross(right, suffix=suffix)) - left_on, right_on = normalize_join_on(on, how, left_on, right_on) - return self._with_compliant( - left.join(right, how=how, left_on=left_on, right_on=right_on, suffix=suffix) + return super().join( + other, how=how, left_on=left_on, right_on=right_on, on=on, suffix=suffix + ) + + def join_asof( + self, + other: Self, + *, + left_on: str | None = None, + right_on: str | None = None, + on: str | None = None, + by_left: str | Sequence[str] | None = None, + by_right: str | Sequence[str] | None = None, + by: str | Sequence[str] | None = None, + strategy: AsofJoinStrategy = "backward", + suffix: str = "_right", + ) -> Self: + return super().join_asof( + other, + left_on=left_on, + right_on=right_on, + on=on, + by_left=by_left, + by_right=by_right, + by=by, + strategy=strategy, + suffix=suffix, ) def filter( @@ -650,6 +744,10 @@ def _is_unique_keep_strategy(obj: Any) -> TypeIs[UniqueKeepStrategy]: return obj in {"any", "first", "last", "none"} +def _is_join_asof_strategy(obj: Any) -> TypeIs[AsofJoinStrategy]: + return obj in {"backward", "forward", "nearest"} + + def _validate_join_strategy(how: str, /) -> JoinStrategy: if _is_join_strategy(how): return how @@ -657,6 +755,13 @@ def _validate_join_strategy(how: str, /) -> JoinStrategy: raise NotImplementedError(msg) +def _validate_join_asof_strategy(strategy: str, /) -> AsofJoinStrategy: + if _is_join_asof_strategy(strategy): + return strategy + msg = f"Only the following join strategies are supported: {get_args(AsofJoinStrategy)}; found '{strategy}'." + raise NotImplementedError(msg) + + def _validate_unique_keep_strategy(keep: str, /) -> UniqueKeepStrategy: if _is_unique_keep_strategy(keep): return keep @@ -671,7 +776,7 @@ def normalize_join_on( right_on: OneOrIterable[str] | None, /, ) -> tuple[Seq[str], Seq[str]]: - """Reduce the 3 potential key (`on*`) arguments to 2. + """Reduce the 3 potential key (`*on`) arguments to 2. Ensures the keys spelling is compatible with the join strategy. """ @@ -692,6 +797,44 @@ def normalize_join_on( return on, on +def normalize_join_asof_on( + left_on: str | None, right_on: str | None, on: str | None +) -> tuple[str, str]: + """Reduce the 3 potential `join_asof` (`*on`) arguments to 2.""" + if on is None: + if left_on is None or right_on is None: + msg = "Either (`left_on` and `right_on`) or `on` keys should be specified." + raise ValueError(msg) + return left_on, right_on + if left_on is not None or right_on is not None: + msg = "If `on` is specified, `left_on` and `right_on` should be None." + raise ValueError(msg) + return on, on + + +def normalize_join_asof_by( + by_left: str | Sequence[str] | None, + by_right: str | Sequence[str] | None, + by: str | Sequence[str] | None, +) -> tuple[Seq[str], Seq[str]]: + """Reduce the 3 potential `join_asof` (`by*`) arguments to 2.""" + if by is None: + if by_left and by_right: + left_by = ensure_seq_str(by_left) + right_by = ensure_seq_str(by_right) + if len(left_by) != len(right_by): + msg = "`by_left` and `by_right` must have the same length." + raise ValueError(msg) + return left_by, right_by + msg = "Can not specify only `by_left` or `by_right`, you need to specify both." + raise ValueError(msg) + if by_left or by_right: + msg = "If `by` is specified, `by_left` and `by_right` should be None." + raise ValueError(msg) + by_ = ensure_seq_str(by) # pragma: no cover + return by_, by_ # pragma: no cover + + def normalize_pivot_args( on: OneOrIterable[str], *, diff --git a/tests/plan/join_test.py b/tests/plan/join_test.py index 8ba63db9b9..5b3cb9d693 100644 --- a/tests/plan/join_test.py +++ b/tests/plan/join_test.py @@ -1,32 +1,47 @@ from __future__ import annotations +import datetime as dt from typing import TYPE_CHECKING, Any, Literal, TypedDict import pytest import narwhals._plan as nwp from narwhals.exceptions import DuplicateError -from tests.plan.utils import assert_equal_data, dataframe +from tests.plan.utils import assert_equal_data, dataframe, re_compile +from tests.utils import PYARROW_VERSION if TYPE_CHECKING: from collections.abc import Sequence from typing_extensions import TypeAlias - from narwhals.typing import JoinStrategy + from narwhals.typing import AsofJoinStrategy, JoinStrategy from tests.conftest import Data - On: TypeAlias = "str | Sequence[str] | None" +By: TypeAlias = "str | Sequence[str] | None" +"""The type of `{by,by_left,by_right}`.""" -class Keywords(TypedDict, total=False): +class AsofKwds(TypedDict, total=False): + """Arguments for `DataFrame.asof`.""" + + on: str | None + left_on: str | None + right_on: str | None + suffix: str + by_left: By + by_right: By + by: By + + +class JoinKwds(TypedDict, total=False): """Arguments for `DataFrame.join`.""" - on: On - how: JoinStrategy - left_on: On - right_on: On + on: str | Sequence[str] | None + left_on: str | Sequence[str] | None + right_on: str | Sequence[str] | None suffix: str + how: JoinStrategy @pytest.fixture @@ -67,9 +82,9 @@ def data_a_only(data: Data) -> Data: LEFT_DATA_1, RIGHT_DATA_1, EXPECTED_DATA_1, - Keywords(left_on=["id"], right_on=["id"]), + JoinKwds(left_on=["id"], right_on=["id"]), ), - (LEFT_DATA_1, RIGHT_DATA_1, EXPECTED_DATA_1, Keywords(on="id")), + (LEFT_DATA_1, RIGHT_DATA_1, EXPECTED_DATA_1, JoinKwds(on="id")), ( { "id": [1, 2, 3, 4], @@ -90,13 +105,13 @@ def data_a_only(data: Data) -> Data: "year_foo": [None, 2021, 2022, 2023, 2024], "value2": [None, 500, 600, 700, 800], }, - Keywords(left_on=["id", "year"], right_on=["id", "year_foo"]), + JoinKwds(left_on=["id", "year"], right_on=["id", "year_foo"]), ), ], ids=["left_on-right_on-identical", "on", "left_on-right_on-different"], ) def test_join_full( - left_data: Data, right_data: Data, expected: Data, kwds: Keywords + left_data: Data, right_data: Data, expected: Data, kwds: JoinKwds ) -> None: kwds["how"] = "full" result = ( @@ -120,8 +135,8 @@ def test_join_inner_x2_duplicate(data_indexed: Data) -> None: df.join(df, "a").join(df, "a") -@pytest.mark.parametrize("kwds", [Keywords(left_on="a", right_on="a"), Keywords(on="a")]) -def test_join_inner_single_key(data_indexed: Data, kwds: Keywords) -> None: +@pytest.mark.parametrize("kwds", [JoinKwds(left_on="a", right_on="a"), JoinKwds(on="a")]) +def test_join_inner_single_key(data_indexed: Data, kwds: JoinKwds) -> None: df = dataframe(data_indexed) result = df.join(df, **kwds).sort("idx").drop("idx_right") expected = { @@ -136,9 +151,9 @@ def test_join_inner_single_key(data_indexed: Data, kwds: Keywords) -> None: @pytest.mark.parametrize( - "kwds", [Keywords(left_on=["a", "b"], right_on=["a", "b"]), Keywords(on=["a", "b"])] + "kwds", [JoinKwds(left_on=["a", "b"], right_on=["a", "b"]), JoinKwds(on=["a", "b"])] ) -def test_join_inner_two_keys(data_indexed: Data, kwds: Keywords) -> None: +def test_join_inner_two_keys(data_indexed: Data, kwds: JoinKwds) -> None: df = dataframe(data_indexed) result = df.join(df, **kwds).sort("idx").drop("idx_right") expected = { @@ -184,7 +199,7 @@ def test_join_left_multiple_column() -> None: ("kwds", "expected"), [ ( - Keywords(left_on="b", right_on="c"), + JoinKwds(left_on="b", right_on="c"), { "a": [1, 2, 3], "b": [4, 5, 6], @@ -195,7 +210,7 @@ def test_join_left_multiple_column() -> None: }, ), ( - Keywords(left_on="a", right_on="d"), + JoinKwds(left_on="a", right_on="d"), { "a": [1, 2, 3], "b": [4, 5, 6], @@ -207,7 +222,7 @@ def test_join_left_multiple_column() -> None: ), ], ) -def test_join_left_overlapping_column(kwds: Keywords, expected: dict[str, Any]) -> None: +def test_join_left_overlapping_column(kwds: JoinKwds, expected: dict[str, Any]) -> None: kwds["how"] = "left" source = { "a": [1.0, 2.0, 3.0], @@ -273,22 +288,22 @@ def test_join_filter( EITHER_LR_OR_ON = r"`left_on` and `right_on`.+or.+`on`" ONLY_ON = r"`on` is specified.+`left_on` and `right_on`.+be.+None" -SAME_LENGTH = r"`left_on` and `right_on`.+same length" +SAME_LENGTH_ON = r"`left_on` and `right_on`.+same length" @pytest.mark.parametrize( ("kwds", "message"), [ - (Keywords(), EITHER_LR_OR_ON), - (Keywords(left_on="a"), EITHER_LR_OR_ON), - (Keywords(right_on="a"), EITHER_LR_OR_ON), - (Keywords(on="a", right_on="a"), ONLY_ON), - (Keywords(left_on=["a", "b"], right_on="a"), SAME_LENGTH), + (JoinKwds(), EITHER_LR_OR_ON), + (JoinKwds(left_on="a"), EITHER_LR_OR_ON), + (JoinKwds(right_on="a"), EITHER_LR_OR_ON), + (JoinKwds(on="a", right_on="a"), ONLY_ON), + (JoinKwds(left_on=["a", "b"], right_on="a"), SAME_LENGTH_ON), ], ) @pytest.mark.parametrize("how", ["inner", "left", "semi", "anti"]) def test_join_keys_exceptions( - how: JoinStrategy, kwds: Keywords, message: str, data: Data + how: JoinStrategy, kwds: JoinKwds, message: str, data: Data ) -> None: df = dataframe(data) kwds["how"] = how @@ -299,13 +314,13 @@ def test_join_keys_exceptions( @pytest.mark.parametrize( "kwds", [ - Keywords(left_on="a"), - Keywords(on="a"), - Keywords(right_on="a"), - Keywords(left_on="a", right_on="a"), + JoinKwds(left_on="a"), + JoinKwds(on="a"), + JoinKwds(right_on="a"), + JoinKwds(left_on="a", right_on="a"), ], ) -def test_join_cross_keys_exceptions(kwds: Keywords, data_a_only: Data) -> None: +def test_join_cross_keys_exceptions(kwds: JoinKwds, data_a_only: Data) -> None: df = dataframe(data_a_only) kwds["how"] = "cross" with pytest.raises(ValueError, match=r"not.+ `left_on`.+`right_on`.+`on`.+cross"): @@ -319,3 +334,165 @@ def test_join_not_implemented(data_a_only: Data) -> None: ) with pytest.raises(NotImplementedError, match=(pattern)): df.join(df, left_on="a", right_on="a", how="right") # type: ignore[arg-type] + + +# NOTE: move `join_asof` to a different file later + + +PYARROW_HAS_JOIN_ASOF = PYARROW_VERSION >= (16, 0, 0) + + +def require_pyarrow_16( + df: nwp.DataFrame[Any, Any], request: pytest.FixtureRequest +) -> None: + request.applymarker( + pytest.mark.xfail( + (df.implementation.is_pyarrow() and not PYARROW_HAS_JOIN_ASOF), + reason="pyarrow too old for `join_asof` support", + raises=NotImplementedError, + ) + ) + + +XFAIL_NEAREST = pytest.mark.xfail( + PYARROW_HAS_JOIN_ASOF, + reason="Only 'backward' and 'forward' strategies are currently supported for `pyarrow`", + raises=NotImplementedError, +) + + +@pytest.mark.parametrize( + ("strategy", "expected_values"), + [ + ("backward", [1, 3, 7]), + ("forward", [1, 6, None]), + pytest.param("nearest", [1, 6, 7], marks=XFAIL_NEAREST), + ], + ids=str, +) +@pytest.mark.parametrize("kwds", [AsofKwds(left_on="a", right_on="a"), AsofKwds(on="a")]) +def test_join_asof_numeric( + strategy: AsofJoinStrategy, + expected_values: list[Any], + request: pytest.FixtureRequest, + kwds: AsofKwds, +) -> None: + left = {"a": [1, 5, 10], "val": ["a", "b", "c"]} + right = {"a": [1, 2, 3, 6, 7], "val": [1, 2, 3, 6, 7]} + expected = left | {"val_right": expected_values} + df = dataframe(left).sort("a") + require_pyarrow_16(df, request) + df_right = dataframe(right).sort("a") + result = df.join_asof(df_right, **kwds, strategy=strategy).sort("a") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("strategy", "expected_values"), + [ + ("backward", [4164, 4566, 4696]), + ("forward", [4411, 4696, 4696]), + pytest.param("nearest", [4164, 4696, 4696], marks=XFAIL_NEAREST), + ], + ids=str, +) +@pytest.mark.parametrize( + "kwds", [AsofKwds(left_on="ts", right_on="ts"), AsofKwds(on="ts")] +) +def test_join_asof_time( + strategy: AsofJoinStrategy, + expected_values: list[float], + request: pytest.FixtureRequest, + kwds: AsofKwds, +) -> None: + left = { + "ts": [dt.datetime(2016, 3, 1), dt.datetime(2018, 8, 1), dt.datetime(2019, 1, 1)], + "pop": [82.19, 82.66, 83.12], + } + right = { + "ts": [ + dt.datetime(2016, 1, 1), + dt.datetime(2017, 1, 1), + dt.datetime(2018, 1, 1), + dt.datetime(2019, 1, 1), + dt.datetime(2020, 1, 1), + ], + "gdp": [4164, 4411, 4566, 4696, 4827], + } + expected = left | {"gdp": expected_values} + df = dataframe(left).sort("ts") + require_pyarrow_16(df, request) + df_right = dataframe(right).sort("ts") + result = df.join_asof(df_right, **kwds, strategy=strategy).sort("ts") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + "kwds", [AsofKwds(on="a", by_left="b", by_right="b"), AsofKwds(on="a", by="b")] +) +def test_join_asof_by(request: pytest.FixtureRequest, kwds: AsofKwds) -> None: + left = {"a": [1, 5, 7, 10], "b": ["D", "D", "C", "A"], "c": [9, 2, 1, 1]} + right = {"a": [1, 4, 5, 8], "b": ["D", "D", "A", "F"], "d": [1, 3, 4, 1]} + expected = { + "a": [1, 5, 7, 10], + "b": ["D", "D", "C", "A"], + "c": [9, 2, 1, 1], + "d": [1, 3, None, 4], + } + df = dataframe(left).sort("a") + require_pyarrow_16(df, request) + df_right = dataframe(right).sort("a") + result = df.join_asof(df_right, **kwds).sort("a") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("kwds", [AsofKwds(left_on="a", right_on="a", suffix="_y")]) +def test_join_asof_suffix(request: pytest.FixtureRequest, kwds: AsofKwds) -> None: + left = {"a": [1, 5, 10], "val": ["a", "b", "c"]} + right = {"a": [1, 2, 3, 6, 7], "val": [1, 2, 3, 6, 7]} + expected = {"a": [1, 5, 10], "val": ["a", "b", "c"], "val_y": [1, 3, 7]} + df = dataframe(left).sort("a") + require_pyarrow_16(df, request) + df_right = dataframe(right).sort("a") + result = df.join_asof(df_right, **kwds).sort("a") + assert_equal_data(result, expected) + + +@pytest.mark.parametrize("strategy", ["back", "furthest"]) +def test_join_asof_not_implemented(strategy: str, data: Data) -> None: + df = dataframe(data) + pattern = re_compile( + rf"supported.+'backward', 'forward', 'nearest'.+ found '{strategy}'" + ) + with pytest.raises(NotImplementedError, match=pattern): + df.join_asof(df, left_on="a", right_on="a", strategy=strategy) # type: ignore[arg-type] + + +EITHER_LR_OR_BY = r"If.+by.+by_left.+by_right.+should be None" +SAME_LENGTH_BY = r"by_left.+by_right.+same.+length" +BOTH_BY = r"not.+by_left.+or.+by_right.+need.+both" +ON = "a" +BY = "b" + + +@pytest.mark.parametrize( + ("kwds", "message"), + [ + (AsofKwds(), EITHER_LR_OR_ON), + (AsofKwds(left_on=ON), EITHER_LR_OR_ON), + (AsofKwds(right_on=ON), EITHER_LR_OR_ON), + (AsofKwds(on=ON, right_on=ON), ONLY_ON), + (AsofKwds(on=ON, left_on=ON, right_on=ON), ONLY_ON), + (AsofKwds(on=ON, left_on=ON), ONLY_ON), + (AsofKwds(on=ON, by=BY, by_left=BY, by_right=BY), EITHER_LR_OR_BY), + (AsofKwds(on=ON, by=BY, by_left=BY), EITHER_LR_OR_BY), + (AsofKwds(on=ON, by=BY, by_right=BY), EITHER_LR_OR_BY), + (AsofKwds(on=ON, by_left=[ON, BY], by_right=BY), SAME_LENGTH_BY), + (AsofKwds(on=ON, by_left=BY), BOTH_BY), + (AsofKwds(on=ON, by_right=BY), BOTH_BY), + ], +) +def test_join_asof_invalid(data: Data, kwds: AsofKwds, message: str) -> None: + df = dataframe(data) + with pytest.raises(ValueError, match=message): + df.join_asof(df, **kwds)