From a22ba9afb88854194f68c81a2ac956e7189b94eb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 20 Feb 2025 18:18:53 +0000 Subject: [PATCH 1/2] chore: remove `kwargs` from `selector` function --- narwhals/_arrow/expr.py | 8 +++----- narwhals/_arrow/namespace.py | 17 +++-------------- narwhals/_arrow/selectors.py | 23 +++++++---------------- narwhals/_pandas_like/expr.py | 18 +++++++++--------- narwhals/_pandas_like/selectors.py | 23 ++++++++--------------- 5 files changed, 30 insertions(+), 59 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index be608d3b3c..ef54b0ba1c 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -43,7 +43,7 @@ def __init__( alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], version: Version, - kwargs: dict[str, Any], + kwargs: dict[str, Any] | None = None, ) -> None: self._call = call self._depth = depth @@ -53,7 +53,7 @@ def __init__( self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version - self._kwargs = kwargs + self._kwargs = {} if kwargs is None else kwargs def __repr__(self: Self) -> str: # pragma: no cover return f"ArrowExpr(depth={self._depth}, function_name={self._function_name}, " @@ -117,7 +117,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: alias_output_names=None, backend_version=backend_version, version=version, - kwargs={}, ) @classmethod @@ -148,7 +147,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: alias_output_names=None, backend_version=backend_version, version=version, - kwargs={}, ) def __narwhals_namespace__(self: Self) -> ArrowNamespace: @@ -315,7 +313,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: alias_output_names=alias_output_names, backend_version=self._backend_version, version=self._version, - kwargs={**self._kwargs, "name": name}, + kwargs=self._kwargs, ) def null_count(self: Self) -> Self: diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 062ab0aaa7..41b6e5fa32 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -49,7 +49,7 @@ def _create_expr_from_callable( function_name: str, evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, - kwargs: dict[str, Any], + kwargs: dict[str, Any] | None = None, ) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr @@ -75,7 +75,6 @@ def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr: alias_output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={}, ) def _create_series_from_scalar( @@ -142,7 +141,6 @@ def len(self: Self) -> ArrowExpr: alias_output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={}, ) def all(self: Self) -> ArrowExpr: @@ -165,7 +163,6 @@ def all(self: Self) -> ArrowExpr: alias_output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={}, ) def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr: @@ -188,7 +185,6 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: alias_output_names=None, backend_version=self._backend_version, version=self._version, - kwargs={}, ) def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: @@ -202,7 +198,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - kwargs={"exprs": exprs}, ) def any_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: @@ -387,11 +382,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), alias_output_names=combine_alias_output_names(*exprs), - kwargs={ - "exprs": exprs, - "separator": separator, - "ignore_nulls": ignore_nulls, - }, ) @@ -466,7 +456,6 @@ def then(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: alias_output_names=getattr(value, "_alias_output_names", None), backend_version=self._backend_version, version=self._version, - kwargs={"value": value}, ) @@ -481,7 +470,7 @@ def __init__( alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], version: Version, - kwargs: dict[str, Any], + kwargs: dict[str, Any] | None = None, ) -> None: self._backend_version = backend_version self._version = version @@ -490,7 +479,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names - self._kwargs = kwargs + self._kwargs = {} if kwargs is None else kwargs def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr: # type ignore because we are setting the `_call` attribute to a diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index f5a6a71a2c..ec045c9e15 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -36,7 +36,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return selector(self, func, evaluate_output_names, {"dtypes": dtypes}) + return selector(self, func, evaluate_output_names) def matches(self: Self, pattern: str) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: @@ -45,7 +45,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return selector(self, func, evaluate_output_names, {"pattern": pattern}) + return selector(self, func, evaluate_output_names) def numeric(self: Self) -> ArrowSelector: dtypes = import_dtypes_module(self._version) @@ -82,7 +82,7 @@ def all(self: Self) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: return [df[col] for col in df.columns] - return selector(self, func, lambda df: df.columns, {}) + return selector(self, func, lambda df: df.columns) def datetime( self: Self, @@ -118,7 +118,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: ) ] - return selector(self, func, evaluate_output_names, {}) + return selector(self, func, evaluate_output_names) class ArrowSelector(ArrowExpr): @@ -134,7 +134,6 @@ def _to_expr(self: Self) -> ArrowExpr: alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, - kwargs=self._kwargs, ) def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any: @@ -151,9 +150,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() - other @@ -175,9 +172,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() | other @@ -195,9 +190,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() & other @@ -210,7 +203,6 @@ def selector( context: _LimitedContext, call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], - kwargs: dict[str, Any], /, ) -> ArrowSelector: return ArrowSelector( @@ -221,5 +213,4 @@ def selector( alias_output_names=None, backend_version=context._backend_version, version=context._version, - kwargs=kwargs, ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index ab13623e02..b725a71ca2 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -431,22 +431,20 @@ def over(self: Self, keys: list[str]) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names, aliases = evaluate_output_names_and_aliases(self, df, []) - reverse = self._kwargs.get("reverse", False) - if reverse: - msg = ( - "Cumulative operation with `reverse=True` is not supported in " - "over context for pandas-like backend." - ) - raise NotImplementedError(msg) - + unsupported_reverse_msg = ( + "Cumulative operation with `reverse=True` is not supported in " + "over context for pandas-like backend." + ) if function_name == "cum_count": + if self._kwargs["reverse"]: + raise NotImplementedError(unsupported_reverse_msg) plx = self.__narwhals_namespace__() df = df.with_columns(~plx.col(*output_names).is_null()) if function_name == "shift": kwargs = {"periods": self._kwargs["n"]} elif function_name == "rank": - _method = self._kwargs.get("method", "average") + _method = self._kwargs["method"] kwargs = { "method": "first" if _method == "ordinal" else _method, "ascending": not self._kwargs["descending"], @@ -454,6 +452,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: "pct": False, } else: # Cumulative operation + if self._kwargs["reverse"]: + raise NotImplementedError(unsupported_reverse_msg) kwargs = {"skipna": True} res_native = getattr( diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 05afa05569..4b7a2ef2e0 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -37,7 +37,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return selector(self, func, evaluate_output_names, {"dtypes": dtypes}) + return selector(self, func, evaluate_output_names) def matches(self: Self, pattern: str) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -46,7 +46,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return selector(self, func, evaluate_output_names, {"pattern": pattern}) + return selector(self, func, evaluate_output_names) def numeric(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) @@ -83,7 +83,7 @@ def all(self: Self) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [df[col] for col in df.columns] - return selector(self, func, lambda df: df.columns, {}) + return selector(self, func, lambda df: df.columns) def datetime( self: Self, @@ -119,7 +119,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: ) ] - return selector(self, func, evaluate_output_names, {}) + return selector(self, func, evaluate_output_names) class PandasSelector(PandasLikeExpr): @@ -155,9 +155,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() - other @@ -179,9 +177,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() | other @@ -199,9 +195,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return selector( - self, call, evaluate_output_names, {**self._kwargs, "other": other} - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() & other @@ -213,7 +207,6 @@ def selector( context: _FullContext, call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]], - kwargs: dict[str, Any], /, ) -> PandasSelector: return PandasSelector( @@ -225,5 +218,5 @@ def selector( implementation=context._implementation, backend_version=context._backend_version, version=context._version, - kwargs=kwargs, + kwargs={}, ) From b4e74afde5bfad122f12b37a4f9752118d037504 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 20 Feb 2025 18:24:39 +0000 Subject: [PATCH 2/2] simplify --- narwhals/_arrow/expr.py | 2 +- narwhals/_arrow/namespace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index ef54b0ba1c..d3378ea8eb 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -53,7 +53,7 @@ def __init__( self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version - self._kwargs = {} if kwargs is None else kwargs + self._kwargs = kwargs or {} def __repr__(self: Self) -> str: # pragma: no cover return f"ArrowExpr(depth={self._depth}, function_name={self._function_name}, " diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 41b6e5fa32..f11e69af8e 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -479,7 +479,7 @@ def __init__( self._function_name = function_name self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names - self._kwargs = {} if kwargs is None else kwargs + self._kwargs = kwargs or {} def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr: # type ignore because we are setting the `_call` attribute to a