From c1d68973e70a2d6a64b96cc39d6cdd9f7bf37245 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 28 Mar 2025 18:40:43 +0000 Subject: [PATCH 1/6] feat: add `rank` for Lazy backends --- narwhals/_compliant/expr.py | 1 - narwhals/_dask/expr.py | 1 + narwhals/_duckdb/expr.py | 28 ++++++++++++++++ narwhals/_spark_like/expr.py | 26 +++++++++++++++ tests/expr_and_series/rank_test.py | 52 ++++++++++++++++++++++++++++-- 5 files changed, 105 insertions(+), 3 deletions(-) diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index e48c773730..6b2fbfcb23 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -879,7 +879,6 @@ class LazyExpr( tail: not_implemented = not_implemented() mode: not_implemented = not_implemented() sort: not_implemented = not_implemented() - rank: not_implemented = not_implemented() sample: not_implemented = not_implemented() map_batches: not_implemented = not_implemented() ewm_mean: not_implemented = not_implemented() diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 6f0dead58f..fbed1b85b7 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -674,3 +674,4 @@ def name(self: Self) -> DaskExprNameNamespace: list = not_implemented() # pyright: ignore[reportAssignmentType] struct = not_implemented() # pyright: ignore[reportAssignmentType] + rank = not_implemented() # pyright: ignore[reportAssignmentType] diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 93931fbe62..b5b239da4f 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -683,6 +683,34 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return self._with_callable(func) + def rank( + self, + method: Literal["average", "min", "max", "dense", "ordinal"], + *, + descending: bool = False, + ) -> Self: + if method == "min": + func_name = "rank" + elif method == "dense": + func_name = "dense_rank" + else: # pragma: no cover + msg = f"Method {method} is not yet implemented." + raise NotImplementedError(msg) + + def _rank(_input: duckdb.Expression) -> duckdb.Expression: + if descending: + by_sql = f"{_input} desc nulls first" + else: + by_sql = f"{_input} asc nulls last" + order_by_sql = f"order by {by_sql}" + sql = ( + f"CASE WHEN {_input} IS NULL THEN NULL " + f"ELSE {func_name}() OVER ({order_by_sql}) END" + ) + return SQLExpression(sql) + + return self._from_call(_rank) + @property def str(self: Self) -> DuckDBExprStringNamespace: return DuckDBExprStringNamespace(self) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 2ed12db3ec..a1cba89a47 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -746,6 +746,32 @@ def rolling_std( ) ) + def rank( + self, + method: Literal["average", "min", "max", "dense", "ordinal"], + *, + descending: bool = False, + ) -> Self: + if method == "min": + func_name = "rank" + elif method == "dense": + func_name = "dense_rank" + else: # pragma: no cover + msg = f"Method {method} is not yet implemented." + raise NotImplementedError(msg) + + def _rank(_input: Column) -> Column: + if descending: + order_by_cols = [self._F.desc_nulls_first(_input)] + else: + order_by_cols = [self._F.asc_nulls_last(_input)] + window = self._Window().orderBy(order_by_cols) + return self._F.when(_input.isNull(), self._F.lit(None)).otherwise( + getattr(self._F, func_name)().over(window) + ) + + return self._from_call(_rank) + @property def str(self: Self) -> SparkLikeExprStringNamespace: return SparkLikeExprStringNamespace(self) diff --git a/tests/expr_and_series/rank_test.py b/tests/expr_and_series/rank_test.py index 90db9a8453..3d32588a16 100644 --- a/tests/expr_and_series/rank_test.py +++ b/tests/expr_and_series/rank_test.py @@ -6,14 +6,20 @@ import pytest import narwhals.stable.v1 as nw +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data rank_methods = ["average", "min", "max", "dense", "ordinal"] -data_int = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2]} -data_float = {"a": [3.1, 6.1, 1.5, 1.5, None, 6.1], "b": [1, 1, 2, 1, 2, 2]} +data_int = {"a": [3, 6, 1, 1, None, 6], "b": [1, 1, 2, 1, 2, 2], "i": [1, 2, 3, 4, 5, 6]} +data_float = { + "a": [3.1, 6.1, 1.5, 1.5, None, 6.1], + "b": [1, 1, 2, 1, 2, 2], + "i": [1, 2, 3, 4, 5, 6], +} expected = { "average": [3.0, 4.5, 1.5, 1.5, None, 4.5], @@ -135,3 +141,45 @@ def test_invalid_method_raise(constructor_eager: ConstructorEager) -> None: with pytest.raises(ValueError, match=msg): df.lazy().collect()["a"].rank(method=method) # type: ignore[arg-type] + + +@pytest.mark.parametrize("method", rank_methods) +@pytest.mark.parametrize("data", [data_int, data_float]) +def test_lazy_rank_expr( + request: pytest.FixtureRequest, + constructor: Constructor, + method: Literal["average", "min", "max", "dense", "ordinal"], + data: dict[str, list[float]], +) -> None: + if ( + "pandas_pyarrow" in str(constructor) + and PANDAS_VERSION < (2, 1) + and isinstance(data["a"][0], int) + ): + request.applymarker(pytest.mark.xfail) + + if ( + any(x in str(constructor) for x in ("pyspark", "duckdb")) + and method in {"average", "max", "ordinal"} + ) or ("duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)): + request.applymarker(pytest.mark.xfail) + + if "dask" in str(constructor): + # `rank` is not implemented in Dask + request.applymarker(pytest.mark.xfail) + + context = ( + pytest.raises( + ValueError, + match=r"`rank` with `method='average' is not supported for pyarrow backend.", + ) + if "pyarrow_table" in str(constructor) and method == "average" + else does_not_raise() + ) + + with context: + df = nw.from_native(constructor(data)) + + result = df.with_columns(a=nw.col("a").rank(method=method)).sort("i").select("a") + expected_data = {"a": expected[method]} + assert_equal_data(result, expected_data) From 4edcf57c3ccdf1736817a9042f0b0e6d1bc993eb Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 28 Mar 2025 18:53:49 +0000 Subject: [PATCH 2/6] fixup --- narwhals/_duckdb/expr.py | 2 +- narwhals/_spark_like/expr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index b5b239da4f..818a61e699 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -709,7 +709,7 @@ def _rank(_input: duckdb.Expression) -> duckdb.Expression: ) return SQLExpression(sql) - return self._from_call(_rank) + return self._with_callable(_rank) @property def str(self: Self) -> DuckDBExprStringNamespace: diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index a1cba89a47..73013e0cd3 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -770,7 +770,7 @@ def _rank(_input: Column) -> Column: getattr(self._F, func_name)().over(window) ) - return self._from_call(_rank) + return self._with_callable(_rank) @property def str(self: Self) -> SparkLikeExprStringNamespace: From dcc14387e73ff21608bb1801bba426a7bb359d29 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 28 Mar 2025 19:15:50 +0000 Subject: [PATCH 3/6] fix descending flag, add tests for it --- narwhals/_duckdb/expr.py | 2 +- narwhals/_spark_like/expr.py | 2 +- tests/expr_and_series/rank_test.py | 54 ++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 818a61e699..562bfe0aea 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -699,7 +699,7 @@ def rank( def _rank(_input: duckdb.Expression) -> duckdb.Expression: if descending: - by_sql = f"{_input} desc nulls first" + by_sql = f"{_input} desc nulls last" else: by_sql = f"{_input} asc nulls last" order_by_sql = f"order by {by_sql}" diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 73013e0cd3..f5dc229754 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -762,7 +762,7 @@ def rank( def _rank(_input: Column) -> Column: if descending: - order_by_cols = [self._F.desc_nulls_first(_input)] + order_by_cols = [self._F.desc_nulls_last(_input)] else: order_by_cols = [self._F.asc_nulls_last(_input)] window = self._Window().orderBy(order_by_cols) diff --git a/tests/expr_and_series/rank_test.py b/tests/expr_and_series/rank_test.py index 3d32588a16..631964c660 100644 --- a/tests/expr_and_series/rank_test.py +++ b/tests/expr_and_series/rank_test.py @@ -29,6 +29,14 @@ "ordinal": [3, 4, 1, 2, None, 5], } +expected_desc = { + "average": [3.0, 1.5, 4.5, 4.5, None, 1.5], + "min": [3, 1, 4, 4, None, 1], + "max": [3, 2, 5, 5, None, 2], + "dense": [2, 1, 3, 3, None, 1], + "ordinal": [3, 1, 4, 5, None, 2], +} + expected_over = { "average": [2.0, 3.0, 1.0, 1.0, None, 2.0], "min": [2, 3, 1, 1, None, 2], @@ -183,3 +191,49 @@ def test_lazy_rank_expr( result = df.with_columns(a=nw.col("a").rank(method=method)).sort("i").select("a") expected_data = {"a": expected[method]} assert_equal_data(result, expected_data) + + +@pytest.mark.parametrize("method", ["average", "min", "max", "dense", "ordinal"]) +@pytest.mark.parametrize("data", [data_float]) +def test_lazy_rank_expr_desc( + request: pytest.FixtureRequest, + constructor: Constructor, + method: Literal["average", "min", "max", "dense", "ordinal"], + data: dict[str, list[float]], +) -> None: + if ( + "pandas_pyarrow" in str(constructor) + and PANDAS_VERSION < (2, 1) + and isinstance(data["a"][0], int) + ): + request.applymarker(pytest.mark.xfail) + + if ( + any(x in str(constructor) for x in ("pyspark", "duckdb")) + and method in {"average", "max", "ordinal"} + ) or ("duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)): + request.applymarker(pytest.mark.xfail) + + if "dask" in str(constructor): + # `rank` is not implemented in Dask + request.applymarker(pytest.mark.xfail) + + context = ( + pytest.raises( + ValueError, + match=r"`rank` with `method='average' is not supported for pyarrow backend.", + ) + if "pyarrow_table" in str(constructor) and method == "average" + else does_not_raise() + ) + + with context: + df = nw.from_native(constructor(data)) + + result = ( + df.with_columns(a=nw.col("a").rank(method=method, descending=True)) + .sort("i") + .select("a") + ) + expected_data = {"a": expected_desc[method]} + assert_equal_data(result, expected_data) From d643504fbfaf6aa5c9fefc1e7b210566d0ed83bf Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 28 Mar 2025 19:18:55 +0000 Subject: [PATCH 4/6] add data_int --- tests/expr_and_series/rank_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/expr_and_series/rank_test.py b/tests/expr_and_series/rank_test.py index 631964c660..da7868ce56 100644 --- a/tests/expr_and_series/rank_test.py +++ b/tests/expr_and_series/rank_test.py @@ -194,7 +194,7 @@ def test_lazy_rank_expr( @pytest.mark.parametrize("method", ["average", "min", "max", "dense", "ordinal"]) -@pytest.mark.parametrize("data", [data_float]) +@pytest.mark.parametrize("data", [data_int, data_float]) def test_lazy_rank_expr_desc( request: pytest.FixtureRequest, constructor: Constructor, From aa53839fc4d75166f74082b7df843f45ea82e981 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Fri, 28 Mar 2025 19:33:50 +0000 Subject: [PATCH 5/6] simplify spark expr --- narwhals/_spark_like/expr.py | 4 ++-- tests/expr_and_series/rank_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index f5dc229754..9f69bea6f0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -766,8 +766,8 @@ def _rank(_input: Column) -> Column: else: order_by_cols = [self._F.asc_nulls_last(_input)] window = self._Window().orderBy(order_by_cols) - return self._F.when(_input.isNull(), self._F.lit(None)).otherwise( - getattr(self._F, func_name)().over(window) + return self._F.when( + _input.isNotNull(), getattr(self._F, func_name)().over(window) ) return self._with_callable(_rank) diff --git a/tests/expr_and_series/rank_test.py b/tests/expr_and_series/rank_test.py index da7868ce56..747577caee 100644 --- a/tests/expr_and_series/rank_test.py +++ b/tests/expr_and_series/rank_test.py @@ -193,7 +193,7 @@ def test_lazy_rank_expr( assert_equal_data(result, expected_data) -@pytest.mark.parametrize("method", ["average", "min", "max", "dense", "ordinal"]) +@pytest.mark.parametrize("method", rank_methods) @pytest.mark.parametrize("data", [data_int, data_float]) def test_lazy_rank_expr_desc( request: pytest.FixtureRequest, From 90047b5ce3cd78e39bb01dde40789d660d43be19 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Sat, 29 Mar 2025 10:39:06 +0000 Subject: [PATCH 6/6] remove default descending args --- narwhals/_duckdb/expr.py | 2 +- narwhals/_spark_like/expr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 562bfe0aea..9142d1c335 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -687,7 +687,7 @@ def rank( self, method: Literal["average", "min", "max", "dense", "ordinal"], *, - descending: bool = False, + descending: bool, ) -> Self: if method == "min": func_name = "rank" diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 9f69bea6f0..45d6e58b48 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -750,7 +750,7 @@ def rank( self, method: Literal["average", "min", "max", "dense", "ordinal"], *, - descending: bool = False, + descending: bool, ) -> Self: if method == "min": func_name = "rank"