diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index 9f11107151..1524464c92 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -43,6 +43,7 @@ - shape - sort - tail + - top_k - to_arrow - to_dict - to_native diff --git a/docs/api-reference/lazyframe.md b/docs/api-reference/lazyframe.md index 4eea39ef1e..b27800d8d9 100644 --- a/docs/api-reference/lazyframe.md +++ b/docs/api-reference/lazyframe.md @@ -25,6 +25,7 @@ - sink_parquet - sort - tail + - top_k - to_native - unique - unpivot diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 38ad41a022..4de78d2792 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -27,6 +27,7 @@ from narwhals.exceptions import ShapeError if TYPE_CHECKING: + from collections.abc import Iterable from io import BytesIO from pathlib import Path from types import ModuleType @@ -442,6 +443,20 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> validate_column_names=False, ) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + if isinstance(reverse, bool): + order: Order = "ascending" if reverse else "descending" + sorting: list[tuple[str, Order]] = [(key, order) for key in by] + else: + sorting = [ + (key, "ascending" if is_ascending else "descending") + for key, is_ascending in zip_strict(by, reverse) + ] + return self._with_native( + self.native.take(pc.select_k_unstable(self.native, k, sorting)), # type: ignore[call-overload] + validate_column_names=False, + ) + def to_pandas(self) -> pd.DataFrame: return self.native.to_pandas() diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index aeb8d05d16..2699f405cb 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -22,7 +22,7 @@ from narwhals.typing import CompliantLazyFrame if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from io import BytesIO from pathlib import Path from types import ModuleType @@ -263,6 +263,22 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> self.native.sort_values(list(by), ascending=ascending, na_position=position) ) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + df = self.native + schema = self.schema + by = list(by) + if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by): + if reverse: + return self._with_native(df.nsmallest(k, by)) + return self._with_native(df.nlargest(k, by)) + if isinstance(reverse, bool): + reverse = [reverse] * len(by) + return self._with_native( + df.sort_values(by, ascending=list(reverse)).head( + n=k, compute=False, npartitions=-1 + ) + ) + def _join_inner( self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str ) -> dd.DataFrame: diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 47dda337f7..848bd9fe83 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -32,7 +32,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from io import BytesIO from pathlib import Path from types import ModuleType @@ -413,6 +413,27 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> ) return self._with_native(self.native.sort(*it)) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + _df = self.native + by = list(by) + if isinstance(reverse, bool): + descending = [not reverse] * len(by) + else: + descending = [not rev for rev in reverse] + expr = window_expression( + F("row_number"), + order_by=by, + descending=descending, + nulls_last=[True] * len(by), + ) + condition = expr <= lit(k) + query = f""" + SELECT * + FROM _df + QUALIFY {condition} + """ # noqa: S608 + return self._with_native(duckdb.sql(query)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: subset_ = subset if subset is not None else self.columns keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_)) diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 72e9adcb7c..7d3369ca66 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -348,6 +348,18 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> return self._with_native(self.native.order_by(*sort_cols)) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + if isinstance(reverse, bool): + reverse = [reverse] * len(list(by)) + sort_cols = [] + + for is_reverse, by_col in zip_strict(reverse, by): + direction_fn = ibis.asc if is_reverse else ibis.desc + col = direction_fn(by_col, nulls_first=False) + sort_cols.append(cast("ir.Column", col)) + + return self._with_native(self.native.order_by(*sort_cols).head(k)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: subset_ = subset if subset is not None else self.columns return self._with_native(self.native.drop_null(subset_)) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 3a258fd9f5..a5995b8145 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -490,6 +490,18 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> validate_column_names=False, ) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + df = self.native + schema = self.schema + if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by): + if reverse: + return self._with_native(df.nsmallest(k, by)) + return self._with_native(df.nlargest(k, by)) + return self._with_native( + df.sort_values(list(by), ascending=reverse).head(k), + validate_column_names=False, + ) + # --- convert --- def collect( self, backend: _EagerAllowedImpl | None, **kwargs: Any diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 1adbf1196f..c9e22f48f9 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -29,6 +29,7 @@ from narwhals.exceptions import ColumnNotFoundError if TYPE_CHECKING: + from collections.abc import Iterable from types import ModuleType from typing import Callable @@ -197,6 +198,19 @@ def join( ) ) + def top_k( + self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] + ) -> Self: + if self._backend_version < (1, 0, 0): + return self._with_native( + self.native.top_k( + k=k, + by=by, + descending=reverse, # type: ignore[call-arg] + ) + ) + return self._with_native(self.native.top_k(k=k, by=by, reverse=reverse)) + def unpivot( self, on: Sequence[str] | None, @@ -547,6 +561,14 @@ def join( except Exception as e: # noqa: BLE001 raise catch_polars_exception(e) from None + def top_k( + self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] + ) -> Self: + try: + return super().top_k(k=k, by=by, reverse=reverse) + except Exception as e: # noqa: BLE001 # pragma: no cover + raise catch_polars_exception(e) from None + class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]): sink_parquet: Method[None] diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 79c409100f..c1ab4ec1fb 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -27,7 +27,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from io import BytesIO from pathlib import Path from types import ModuleType @@ -340,6 +340,16 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)] return self._with_native(self.native.sort(*sort_cols)) + def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self: + by = list(by) + if isinstance(reverse, bool): + reverse = [reverse] * len(by) + sort_funcs = ( + self._F.desc_nulls_last if not d else self._F.asc_nulls_last for d in reverse + ) + sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)] + return self._with_native(self.native.sort(*sort_cols).limit(k)) + def drop_nulls(self, subset: Sequence[str] | None) -> Self: subset = list(subset) if subset else None return self._with_native(self.native.dropna(subset=subset)) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 19def52e03..88745b59f4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -245,6 +245,14 @@ def sort( self._compliant_frame.sort(*by, descending=descending, nulls_last=nulls_last) ) + def top_k( + self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False + ) -> Self: + flatten_by = flatten([by]) + return self._with_compliant( + self._compliant_frame.top_k(k, by=flatten_by, reverse=reverse) + ) + def join( self, other: Self, @@ -1747,6 +1755,43 @@ def sort( """ return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last) + def top_k( + self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False + ) -> Self: + r"""Return the `k` largest rows. + + Non-null elements are always preferred over null elements, + regardless of the value of reverse. The output is not guaranteed + to be in any particular order, sort the outputs afterwards if you wish the output to be sorted. + + Arguments: + k: Number of rows to return. + by: Column(s) used to determine the top rows. Accepts expression input. Strings are parsed as column names. + reverse: Consider the k smallest elements of the by column(s) (instead of the k largest). + This can be specified per column by passing a sequence of booleans. + + Returns: + The dataframe with the `k` largest rows. + + Examples: + >>> import pandas as pd + >>> import narwhals as nw + >>> df_native = pd.DataFrame( + ... {"a": ["a", "b", "a", "b", None, "c"], "b": [2, 1, 1, 3, 2, 1]} + ... ) + >>> nw.from_native(df_native).top_k(4, by=["b", "a"]) + ┌──────────────────┐ + |Narwhals DataFrame| + |------------------| + | a b | + | 3 b 3 | + | 0 a 2 | + | 4 None 2 | + | 5 c 1 | + └──────────────────┘ + """ + return super().top_k(k, by=by, reverse=reverse) + def join( self, other: Self, @@ -2983,6 +3028,48 @@ def sort( """ return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last) + def top_k( + self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False + ) -> Self: + r"""Return the `k` largest rows. + + Non-null elements are always preferred over null elements, + regardless of the value of reverse. The output is not guaranteed + to be in any particular order, sort the outputs afterwards if you wish the output to be sorted. + + Arguments: + k: Number of rows to return. + by: Column(s) used to determine the top rows. Accepts expression input. Strings are parsed as column names. + reverse: Consider the k smallest elements of the by column(s) (instead of the k largest). + This can be specified per column by passing a sequence of booleans. + + Returns: + The LazyFrame with the `k` largest rows. + + Examples: + >>> import duckdb + >>> import narwhals as nw + >>> df_native = duckdb.sql( + ... "SELECT * FROM VALUES ('a', 2), ('b', 1), ('a', 1), ('b', 3), (NULL, 2), ('c', 1) df(a, b)" + ... ) + >>> df = nw.from_native(df_native) + >>> df.top_k(4, by=["b", "a"]) + ┌───────────────────┐ + |Narwhals LazyFrame | + |-------------------| + |┌─────────┬───────┐| + |│ a │ b │| + |│ varchar │ int32 │| + |├─────────┼───────┤| + |│ b │ 3 │| + |│ a │ 2 │| + |│ NULL │ 2 │| + |│ c │ 1 │| + |└─────────┴───────┘| + └───────────────────┘ + """ + return super().top_k(k, by=by, reverse=reverse) + def join( self, other: Self, diff --git a/tests/frame/top_k_test.py b/tests/frame/top_k_test.py new file mode 100644 index 0000000000..bb995f44aa --- /dev/null +++ b/tests/frame/top_k_test.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import pytest + +import narwhals as nw +from tests.utils import POLARS_VERSION, Constructor, assert_equal_data + + +def test_top_k(constructor: Constructor) -> None: + if "polars" in str(constructor) and POLARS_VERSION < (1, 0): + # old polars versions do not sort nulls last + pytest.skip() + data = {"a": ["a", "f", "a", "d", "b", "c"], "b c": [None, None, 2, 3, 6, 1]} + df = nw.from_native(constructor(data)) + result = df.top_k(4, by="b c") + expected = {"a": ["a", "b", "c", "d"], "b c": [2, 6, 1, 3]} + assert_equal_data(result.sort("a"), expected) + df = nw.from_native(constructor(data)) + result = df.top_k(4, by="b c", reverse=True) + expected = {"a": ["a", "b", "c", "d"], "b c": [2, 6, 1, 3]} + assert_equal_data(result.sort(by="a"), expected) + + +def test_top_k_by_multiple(constructor: Constructor) -> None: + data = { + "a": ["a", "f", "a", "d", "b", "c"], + "b": [2, 2, 2, 3, 1, 1], + "sf_c": ["k", "d", "s", "a", "a", "r"], + } + df = nw.from_native(constructor(data)) + result = df.top_k(4, by=["b", "sf_c"], reverse=True) + expected = { + "a": ["b", "f", "a", "c"], + "b": [1, 2, 2, 1], + "sf_c": ["a", "d", "k", "r"], + } + assert_equal_data(result.sort("sf_c"), expected) + data = { + "a": ["a", "f", "a", "d", "b", "c"], + "b": [2, 2, 2, 3, 1, 1], + "sf_c": ["k", "d", "s", "a", "a", "r"], + } + df = nw.from_native(constructor(data)) + result = df.top_k(4, by=["b", "sf_c"], reverse=[False, True]) + expected = { + "a": ["d", "f", "a", "a"], + "b": [3, 2, 2, 2], + "sf_c": ["a", "d", "k", "s"], + } + assert_equal_data(result.sort("sf_c"), expected)