Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- shape
- sort
- tail
- top_k
- to_arrow
- to_dict
- to_native
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/lazyframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
- sink_parquet
- sort
- tail
- top_k
- to_native
- unique
- unpivot
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
from collections.abc import Iterable
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -442,6 +443,20 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
validate_column_names=False,
)

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    `reverse=True` (globally or per column) ranks ascending instead,
    i.e. takes the `k` smallest.
    """

    def direction(ascending: bool) -> Order:
        # select_k_unstable keeps the *first* k rows of the given ordering,
        # so "largest" maps to a descending sort and `reverse` flips that.
        return "ascending" if ascending else "descending"

    if isinstance(reverse, bool):
        sorting = [(name, direction(reverse)) for name in by]
    else:
        # zip_strict raises if `by` and `reverse` differ in length.
        sorting = [(name, direction(asc)) for name, asc in zip_strict(by, reverse)]
    top_indices = pc.select_k_unstable(self.native, k, sorting)  # type: ignore[call-overload]
    return self._with_native(
        self.native.take(top_indices), validate_column_names=False
    )

def to_pandas(self) -> pd.DataFrame:
return self.native.to_pandas()

Expand Down
18 changes: 17 additions & 1 deletion narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -263,6 +263,22 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
self.native.sort_values(list(by), ascending=ascending, na_position=position)
)

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    `reverse=True` (globally or per column) takes the `k` smallest instead.
    """
    native = self.native
    columns = list(by)
    schema = self.schema
    uniform = isinstance(reverse, bool)
    if uniform and all(schema[name].is_numeric() for name in columns):
        # All-numeric fast path: nlargest/nsmallest avoid a full sort.
        picker = native.nsmallest if reverse else native.nlargest
        return self._with_native(picker(k, columns))
    ascending = [reverse] * len(columns) if uniform else list(reverse)
    ordered = native.sort_values(columns, ascending=ascending)
    # npartitions=-1: take the head across all partitions, lazily.
    return self._with_native(ordered.head(n=k, compute=False, npartitions=-1))

def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
Expand Down
23 changes: 22 additions & 1 deletion narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -413,6 +413,27 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
)
return self._with_native(self.native.sort(*it))

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    `reverse=True` (globally or per column) ranks ascending instead,
    i.e. takes the `k` smallest.
    """
    # NOTE: `_df` looks unused but is required — duckdb.sql resolves the
    # `_df` name inside the query text from this local scope
    # (duckdb "replacement scan").
    _df = self.native
    by = list(by)
    # Ranking is descending by default ("largest"); `reverse` flips it.
    if isinstance(reverse, bool):
        descending = [not reverse] * len(by)
    else:
        descending = [not rev for rev in reverse]
    # row_number() over the requested ordering; nulls_last=True for every
    # key so non-null elements are always preferred over nulls.
    expr = window_expression(
        F("row_number"),
        order_by=by,
        descending=descending,
        nulls_last=[True] * len(by),
    )
    condition = expr <= lit(k)
    # QUALIFY filters on the window expression without a subquery.
    query = f"""
    SELECT *
    FROM _df
    QUALIFY {condition}
    """  # noqa: S608
    return self._with_native(duckdb.sql(query))

def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset_ = subset if subset is not None else self.columns
keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->

return self._with_native(self.native.order_by(*sort_cols))

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    Arguments:
        k: Number of rows to keep.
        by: Column names used for ranking.
        reverse: Rank ascending (take the smallest) instead; a single bool
            applies to all columns, a sequence is matched per column.

    Returns:
        A frame holding (at most) the top `k` rows.
    """
    # Materialize `by` exactly once: the previous `len(list(by))` consumed
    # a one-shot iterator and then handed the exhausted iterable to
    # `zip_strict`, silently producing no sort keys.
    by = list(by)
    if isinstance(reverse, bool):
        reverse = [reverse] * len(by)
    sort_cols = []

    for is_reverse, by_col in zip_strict(reverse, by):
        # reverse=True means "smallest first", hence ascending order.
        direction_fn = ibis.asc if is_reverse else ibis.desc
        col = direction_fn(by_col, nulls_first=False)
        sort_cols.append(cast("ir.Column", col))

    return self._with_native(self.native.order_by(*sort_cols).head(k))

def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset_ = subset if subset is not None else self.columns
return self._with_native(self.native.drop_null(subset_))
Expand Down
12 changes: 12 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,18 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
validate_column_names=False,
)

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    Arguments:
        k: Number of rows to keep.
        by: Column names used for ranking.
        reverse: Rank ascending (take the smallest) instead; a single bool
            applies to all columns, a sequence is matched per column.

    Returns:
        A frame holding (at most) the top `k` rows.
    """
    df = self.native
    schema = self.schema
    # Materialize `by` exactly once: a one-shot iterator would be exhausted
    # by the `all(...)` scan below, leaving nothing for nlargest/sort_values.
    by = list(by)
    if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
        # All-numeric fast path: nlargest/nsmallest avoid a full sort.
        if reverse:
            return self._with_native(df.nsmallest(k, by))
        return self._with_native(df.nlargest(k, by))
    return self._with_native(
        df.sort_values(by, ascending=reverse).head(k),
        validate_column_names=False,
    )

# --- convert ---
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
Expand Down
22 changes: 22 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
from collections.abc import Iterable
from types import ModuleType
from typing import Callable

Expand Down Expand Up @@ -197,6 +198,19 @@ def join(
)
)

def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
) -> Self:
    """Delegate to polars `top_k`, bridging the pre-1.0 keyword rename."""
    if self._backend_version >= (1, 0, 0):
        return self._with_native(self.native.top_k(k=k, by=by, reverse=reverse))
    # polars < 1.0 called this keyword `descending` rather than `reverse`.
    return self._with_native(
        self.native.top_k(
            k=k,
            by=by,
            descending=reverse,  # type: ignore[call-arg]
        )
    )

def unpivot(
self,
on: Sequence[str] | None,
Expand Down Expand Up @@ -547,6 +561,14 @@ def join(
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None

def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
) -> Self:
    # Delegate to the shared base implementation, translating any native
    # polars error into the corresponding narwhals exception.
    try:
        return super().top_k(k=k, by=by, reverse=reverse)
    except Exception as e:  # noqa: BLE001 # pragma: no cover
        raise catch_polars_exception(e) from None


class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
sink_parquet: Method[None]
Expand Down
12 changes: 11 additions & 1 deletion narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -340,6 +340,16 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
return self._with_native(self.native.sort(*sort_cols))

def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
    """Return the `k` rows ranking highest by the `by` columns.

    `reverse=True` (globally or per column) takes the `k` smallest instead.
    Nulls sort last, so non-null elements are always preferred.
    """
    columns = list(by)
    if isinstance(reverse, bool):
        reverse = [reverse] * len(columns)
    # reverse=True means "smallest first", hence ascending order.
    ordered = [
        self._F.asc_nulls_last(name) if asc else self._F.desc_nulls_last(name)
        for name, asc in zip_strict(columns, reverse)
    ]
    return self._with_native(self.native.sort(*ordered).limit(k))

def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset = list(subset) if subset else None
return self._with_native(self.native.dropna(subset=subset))
Expand Down
87 changes: 87 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ def sort(
self._compliant_frame.sort(*by, descending=descending, nulls_last=nulls_last)
)

def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False
) -> Self:
    """Return the `k` largest rows, ranking by the `by` column(s).

    Arguments:
        k: Number of rows to return.
        by: Column(s) used to determine the top rows.
        reverse: Consider the `k` smallest elements instead; a single bool
            applies to every `by` column, a sequence is matched per column.

    Returns:
        A new frame with (at most) the top `k` rows.

    Raises:
        ValueError: If `reverse` is a sequence whose length differs from the
            number of `by` columns.
    """
    flatten_by = flatten([by])
    # Validate lengths here (error text aligned with polars) so every
    # backend can rely on zip(by, reverse) behaving like zip_strict.
    if not isinstance(reverse, bool) and len(reverse) != len(flatten_by):
        msg = (
            f"the length of `reverse` ({len(reverse)}) does not match "
            f"the length of `by` ({len(flatten_by)})"
        )
        raise ValueError(msg)
    return self._with_compliant(
        self._compliant_frame.top_k(k, by=flatten_by, reverse=reverse)
    )

def join(
self,
other: Self,
Expand Down Expand Up @@ -1747,6 +1755,43 @@ def sort(
"""
return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last)

def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False
) -> Self:
    r"""Return the `k` largest rows.

    Non-null elements are always preferred over null elements,
    regardless of the value of `reverse`. The output is not guaranteed
    to be in any particular order; sort the output afterwards if you wish it to be sorted.

    Arguments:
        k: Number of rows to return.
        by: Column(s) used to determine the top rows. Accepts expression input. Strings are parsed as column names.
        reverse: Consider the `k` smallest elements of the `by` column(s) (instead of the `k` largest).
            This can be specified per column by passing a sequence of booleans.

    Returns:
        The dataframe with the `k` largest rows.

    Examples:
        >>> import pandas as pd
        >>> import narwhals as nw
        >>> df_native = pd.DataFrame(
        ...     {"a": ["a", "b", "a", "b", None, "c"], "b": [2, 1, 1, 3, 2, 1]}
        ... )
        >>> nw.from_native(df_native).top_k(4, by=["b", "a"])
        β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
        |Narwhals DataFrame|
        |------------------|
        | a b |
        | 3 b 3 |
        | 0 a 2 |
        | 4 None 2 |
        | 5 c 1 |
        β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    """
    return super().top_k(k, by=by, reverse=reverse)

def join(
self,
other: Self,
Expand Down Expand Up @@ -2983,6 +3028,48 @@ def sort(
"""
return super().sort(by, *more_by, descending=descending, nulls_last=nulls_last)

def top_k(
    self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool] = False
) -> Self:
    r"""Return the `k` largest rows.

    Non-null elements are always preferred over null elements,
    regardless of the value of `reverse`. The output is not guaranteed
    to be in any particular order; sort the output afterwards if you wish it to be sorted.

    Arguments:
        k: Number of rows to return.
        by: Column(s) used to determine the top rows. Accepts expression input. Strings are parsed as column names.
        reverse: Consider the `k` smallest elements of the `by` column(s) (instead of the `k` largest).
            This can be specified per column by passing a sequence of booleans.

    Returns:
        The LazyFrame with the `k` largest rows.

    Examples:
        >>> import duckdb
        >>> import narwhals as nw
        >>> df_native = duckdb.sql(
        ...     "SELECT * FROM VALUES ('a', 2), ('b', 1), ('a', 1), ('b', 3), (NULL, 2), ('c', 1) df(a, b)"
        ... )
        >>> df = nw.from_native(df_native)
        >>> df.top_k(4, by=["b", "a"])
        β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
        |Narwhals LazyFrame |
        |-------------------|
        |β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”|
        |β”‚ a β”‚ b β”‚|
        |β”‚ varchar β”‚ int32 β”‚|
        |β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€|
        |β”‚ b β”‚ 3 β”‚|
        |β”‚ a β”‚ 2 β”‚|
        |β”‚ NULL β”‚ 2 β”‚|
        |β”‚ c β”‚ 1 β”‚|
        |β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜|
        β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    """
    return super().top_k(k, by=by, reverse=reverse)

def join(
self,
other: Self,
Expand Down
50 changes: 50 additions & 0 deletions tests/frame/top_k_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import pytest

import narwhals as nw
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data


def test_top_k(constructor: Constructor) -> None:
    if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
        # old polars versions do not sort nulls last
        pytest.skip()
    data = {"a": ["a", "f", "a", "d", "b", "c"], "b c": [None, None, 2, 3, 6, 1]}
    # Both directions select the same four rows here: nulls are never
    # preferred, and exactly four non-null values exist in "b c".
    expected = {"a": ["a", "b", "c", "d"], "b c": [2, 6, 1, 3]}
    frame = nw.from_native(constructor(data))
    assert_equal_data(frame.top_k(4, by="b c").sort("a"), expected)
    frame = nw.from_native(constructor(data))
    assert_equal_data(frame.top_k(4, by="b c", reverse=True).sort(by="a"), expected)


def test_top_k_by_multiple(constructor: Constructor) -> None:
    data = {
        "a": ["a", "f", "a", "d", "b", "c"],
        "b": [2, 2, 2, 3, 1, 1],
        "sf_c": ["k", "d", "s", "a", "a", "r"],
    }
    # All columns reversed: the four smallest (b, sf_c) pairs.
    frame = nw.from_native(constructor(data))
    result = frame.top_k(4, by=["b", "sf_c"], reverse=True)
    assert_equal_data(
        result.sort("sf_c"),
        {
            "a": ["b", "f", "a", "c"],
            "b": [1, 2, 2, 1],
            "sf_c": ["a", "d", "k", "r"],
        },
    )
    # Mixed per-column directions: largest `b` first, then smallest `sf_c`.
    frame = nw.from_native(constructor(data))
    result = frame.top_k(4, by=["b", "sf_c"], reverse=[False, True])
    assert_equal_data(
        result.sort("sf_c"),
        {
            "a": ["d", "f", "a", "a"],
            "b": [3, 2, 2, 2],
            "sf_c": ["a", "d", "k", "s"],
        },
    )
Loading