Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,6 @@ class LazyExpr( # type: ignore[misc]
tail: not_implemented = not_implemented()
mode: not_implemented = not_implemented()
sort: not_implemented = not_implemented()
rank: not_implemented = not_implemented()
sample: not_implemented = not_implemented()
map_batches: not_implemented = not_implemented()
ewm_mean: not_implemented = not_implemented()
Expand Down
1 change: 1 addition & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,4 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:

list = not_implemented() # pyright: ignore[reportAssignmentType]
struct = not_implemented() # pyright: ignore[reportAssignmentType]
rank = not_implemented() # pyright: ignore[reportAssignmentType]
28 changes: 28 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,34 @@ def func(_input: duckdb.Expression) -> duckdb.Expression:

return self._with_callable(func)

def rank(
    self,
    method: Literal["average", "min", "max", "dense", "ordinal"],
    *,
    descending: bool,
) -> Self:
    """Rank the values of this expression using a DuckDB window function.

    Args:
        method: Tie-handling strategy. Only ``"min"`` (SQL ``rank``) and
            ``"dense"`` (SQL ``dense_rank``) are handled here.
        descending: Order the window descending when true, ascending
            otherwise. Nulls are ordered last in both cases.

    Returns:
        Whatever ``self._with_callable`` returns for the ranking callable.

    Raises:
        NotImplementedError: If ``method`` is neither ``"min"`` nor
            ``"dense"``.
    """
    if method == "min":
        func_name = "rank"
    elif method == "dense":
        func_name = "dense_rank"
    else:  # pragma: no cover
        msg = f"Method {method} is not yet implemented."
        raise NotImplementedError(msg)

    def _rank(_input: duckdb.Expression) -> duckdb.Expression:
        # Only the direction keyword differs between the two orderings.
        direction = "desc" if descending else "asc"
        order_by_sql = f"order by {_input} {direction} nulls last"
        # Review request: guard on the not-null case, mirroring the
        # spark-like implementation's `when(isNotNull, ...)`. A CASE with
        # no ELSE branch yields NULL, so null inputs still get a null rank.
        sql = (
            f"CASE WHEN {_input} IS NOT NULL "
            f"THEN {func_name}() OVER ({order_by_sql}) END"
        )
        return SQLExpression(sql)

    return self._with_callable(_rank)

@property
def str(self: Self) -> DuckDBExprStringNamespace:
return DuckDBExprStringNamespace(self)
Expand Down
26 changes: 26 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,32 @@ def rolling_std(
)
)

def rank(
    self,
    method: Literal["average", "min", "max", "dense", "ordinal"],
    *,
    descending: bool,
) -> Self:
    """Rank the values of this expression using a Spark window function.

    Args:
        method: Tie-handling strategy. Only ``"min"`` (``rank``) and
            ``"dense"`` (``dense_rank``) are handled here.
        descending: Order the window descending when true, ascending
            otherwise. Nulls are ordered last in both cases.

    Returns:
        Whatever ``self._with_callable`` returns for the ranking callable.

    Raises:
        NotImplementedError: If ``method`` is neither ``"min"`` nor
            ``"dense"``.
    """
    # Maps the supported methods to the name of the function looked up on
    # `self._F`.
    window_functions = {"min": "rank", "dense": "dense_rank"}
    if method not in window_functions:  # pragma: no cover
        msg = f"Method {method} is not yet implemented."
        raise NotImplementedError(msg)
    func_name = window_functions[method]

    def _rank(_input: Column) -> Column:
        to_ordering = (
            self._F.desc_nulls_last if descending else self._F.asc_nulls_last
        )
        window = self._Window().orderBy([to_ordering(_input)])
        ranked = getattr(self._F, func_name)().over(window)
        # `when` is given no fallback branch, so only non-null inputs
        # receive a rank.
        return self._F.when(_input.isNotNull(), ranked)

    return self._with_callable(_rank)

@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
Expand Down
106 changes: 104 additions & 2 deletions tests/expr_and_series/rank_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import DUCKDB_VERSION
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

rank_methods = ["average", "min", "max", "dense", "ordinal"]

data_int = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2]}
data_float = {"a": [3.1, 6.1, 1.5, 1.5, None, 6.1], "b": [1, 1, 2, 1, 2, 2]}
data_int = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2], "i": [1, 2, 3, 4, 5, 6]}
data_float = {
"a": [3.1, 6.1, 1.5, 1.5, None, 6.1],
"b": [1, 1, 2, 1, 2, 2],
"i": [1, 2, 3, 4, 5, 6],
}

expected = {
"average": [3.0, 4.5, 1.5, 1.5, None, 4.5],
Expand All @@ -23,6 +29,14 @@
"ordinal": [3, 4, 1, 2, None, 5],
}

expected_desc = {
"average": [3.0, 1.5, 4.5, 4.5, None, 1.5],
"min": [3, 1, 4, 4, None, 1],
"max": [3, 2, 5, 5, None, 2],
"dense": [2, 1, 3, 3, None, 1],
"ordinal": [3, 1, 4, 5, None, 2],
}

expected_over = {
"average": [2.0, 3.0, 1.0, 1.0, None, 2.0],
"min": [2, 3, 1, 1, None, 2],
Expand Down Expand Up @@ -135,3 +149,91 @@ def test_invalid_method_raise(constructor_eager: ConstructorEager) -> None:

with pytest.raises(ValueError, match=msg):
df.lazy().collect()["a"].rank(method=method) # type: ignore[arg-type]


def _mark_expected_rank_failures(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    method: Literal["average", "min", "max", "dense", "ordinal"],
    data: dict[str, list[float]],
) -> None:
    """Apply the xfail markers shared by the lazy ``rank`` tests.

    Args:
        request: Pytest request object the markers are applied to.
        constructor: Backend constructor under test; its ``str`` form is
            used to identify the backend.
        method: Rank method being exercised.
        data: Input columns; the type of the first value of ``"a"`` is
            checked to tell the integer dataset from the float one.
    """
    backend = str(constructor)

    if (
        "pandas_pyarrow" in backend
        and PANDAS_VERSION < (2, 1)
        and isinstance(data["a"][0], int)
    ):
        request.applymarker(pytest.mark.xfail)

    if (
        any(x in backend for x in ("pyspark", "duckdb"))
        and method in {"average", "max", "ordinal"}
    ) or ("duckdb" in backend and DUCKDB_VERSION < (1, 3)):
        request.applymarker(pytest.mark.xfail)

    if "dask" in backend:
        # `rank` is not implemented in Dask
        request.applymarker(pytest.mark.xfail)


@pytest.mark.parametrize("method", rank_methods)
@pytest.mark.parametrize("data", [data_int, data_float])
def test_lazy_rank_expr(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    method: Literal["average", "min", "max", "dense", "ordinal"],
    data: dict[str, list[float]],
) -> None:
    """Check ``rank`` with its default ordering against ``expected``."""
    _mark_expected_rank_failures(request, constructor, method, data)

    context = (
        pytest.raises(
            ValueError,
            match=r"`rank` with `method='average' is not supported for pyarrow backend.",
        )
        if "pyarrow_table" in str(constructor) and method == "average"
        else does_not_raise()
    )

    with context:
        df = nw.from_native(constructor(data))

        # Sort by the index column "i" so the ranks line up with the
        # expected values in input-row order.
        result = df.with_columns(a=nw.col("a").rank(method=method)).sort("i").select("a")
        expected_data = {"a": expected[method]}
        assert_equal_data(result, expected_data)


@pytest.mark.parametrize("method", rank_methods)
@pytest.mark.parametrize("data", [data_int, data_float])
def test_lazy_rank_expr_desc(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    method: Literal["average", "min", "max", "dense", "ordinal"],
    data: dict[str, list[float]],
) -> None:
    """Check ``rank(descending=True)`` against ``expected_desc``."""
    _mark_expected_rank_failures(request, constructor, method, data)

    context = (
        pytest.raises(
            ValueError,
            match=r"`rank` with `method='average' is not supported for pyarrow backend.",
        )
        if "pyarrow_table" in str(constructor) and method == "average"
        else does_not_raise()
    )

    with context:
        df = nw.from_native(constructor(data))

        # Sort by the index column "i" so the ranks line up with the
        # expected values in input-row order.
        result = (
            df.with_columns(a=nw.col("a").rank(method=method, descending=True))
            .sort("i")
            .select("a")
        )
        expected_data = {"a": expected_desc[method]}
        assert_equal_data(result, expected_data)
Loading