From e9b04becbb5b14c2e73b79d9ca271a5332a6a238 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Mar 2025 14:32:49 +0000 Subject: [PATCH 1/5] refactor: Use `PolarsExpr.native` --- narwhals/_polars/expr.py | 84 ++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index f8c2fe41d2..4b3c711277 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -59,14 +59,12 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._from_native_expr( - getattr(self._native_expr, attr)(*args, **kwargs) - ) + return self._from_native_expr(getattr(self.native, attr)(*args, **kwargs)) return func def cast(self: Self, dtype: DType) -> Self: - expr = self._native_expr + expr = self.native dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._from_native_expr(expr.cast(dtype_pl)) @@ -81,7 +79,7 @@ def ewm_mean( min_samples: int, ignore_nulls: bool, ) -> Self: - expr = self._native_expr + expr = self.native extra_kwargs = ( {"min_periods": min_samples} @@ -107,9 +105,9 @@ def ewm_mean( def is_nan(self: Self) -> Self: if self._backend_version < (1, 18): # pragma: no cover return self._from_native_expr( - pl.when(self._native_expr.is_not_null()).then(self._native_expr.is_nan()) + pl.when(self.native.is_not_null()).then(self.native.is_nan()) ) - return self._from_native_expr(self._native_expr.is_nan()) + return self._from_native_expr(self.native.is_nan()) def over( self: Self, @@ -120,11 +118,9 @@ def over( if order_by: msg = "`order_by` in Polars requires version 1.10 or greater" raise NotImplementedError(msg) - return self._from_native_expr( - self._native_expr.over(partition_by or pl.lit(1)) - ) + return self._from_native_expr(self.native.over(partition_by or pl.lit(1))) return self._from_native_expr( - self._native_expr.over(partition_by or pl.lit(1), order_by=order_by) + self.native.over(partition_by or pl.lit(1), order_by=order_by) ) def rolling_var( @@ -145,7 +141,7 @@ def rolling_var( else {"min_samples": min_samples} ) return self._from_native_expr( - self._native_expr.rolling_var( + self.native.rolling_var( window_size=window_size, center=center, ddof=ddof, @@ -171,7 +167,7 @@ def rolling_std( ) return self._from_native_expr( - self._native_expr.rolling_std( + self.native.rolling_std( window_size=window_size, center=center, ddof=ddof, @@ -189,7 +185,7 @@ def rolling_sum( ) return self._from_native_expr( - self._native_expr.rolling_sum( + self.native.rolling_sum( window_size=window_size, center=center, **extra_kwargs, # type: ignore[arg-type] @@ -210,7 +206,7 @@ def rolling_mean( ) return self._from_native_expr( - self._native_expr.rolling_mean( + self.native.rolling_mean( window_size=window_size, center=center, **extra_kwargs, # type: ignore[arg-type] @@ -227,15 +223,15 @@ def map_batches( return_dtype, self._version, self._backend_version ) return self._from_native_expr( - self._native_expr.map_batches(function, return_dtype_pl) + self.native.map_batches(function, return_dtype_pl) ) else: - return self._from_native_expr(self._native_expr.map_batches(function)) + return self._from_native_expr(self.native.map_batches(function)) def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: - expr = self._native_expr + expr = self.native return_dtype_pl = ( narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype @@ -249,63 +245,59 @@ def replace_strict( ) def __eq__(self: Self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self._native_expr.__eq__(extract_native(other))) # type: ignore[operator] + return self._from_native_expr(self.native.__eq__(extract_native(other))) # type: ignore[operator] def __ne__(self: Self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self._native_expr.__ne__(extract_native(other))) # type: ignore[operator] + return self._from_native_expr(self.native.__ne__(extract_native(other))) # type: ignore[operator] def __ge__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__ge__(extract_native(other))) + return self._from_native_expr(self.native.__ge__(extract_native(other))) def __gt__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__gt__(extract_native(other))) + return self._from_native_expr(self.native.__gt__(extract_native(other))) def __le__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__le__(extract_native(other))) + return self._from_native_expr(self.native.__le__(extract_native(other))) def __lt__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__lt__(extract_native(other))) + return self._from_native_expr(self.native.__lt__(extract_native(other))) def __and__(self: Self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self._native_expr.__and__(extract_native(other))) # type: ignore[operator] + return self._from_native_expr(self.native.__and__(extract_native(other))) # type: ignore[operator] def __or__(self: Self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self._native_expr.__or__(extract_native(other))) # type: ignore[operator] + return self._from_native_expr(self.native.__or__(extract_native(other))) # type: ignore[operator] def __add__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__add__(extract_native(other))) + return self._from_native_expr(self.native.__add__(extract_native(other))) def __sub__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__sub__(extract_native(other))) + return self._from_native_expr(self.native.__sub__(extract_native(other))) def __mul__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__mul__(extract_native(other))) + return self._from_native_expr(self.native.__mul__(extract_native(other))) def __pow__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__pow__(extract_native(other))) + return self._from_native_expr(self.native.__pow__(extract_native(other))) def __truediv__(self: Self, other: Any) -> Self: - return self._from_native_expr( - self._native_expr.__truediv__(extract_native(other)) - ) + return self._from_native_expr(self.native.__truediv__(extract_native(other))) def __floordiv__(self: Self, other: Any) -> Self: - return self._from_native_expr( - self._native_expr.__floordiv__(extract_native(other)) - ) + return self._from_native_expr(self.native.__floordiv__(extract_native(other))) def __mod__(self: Self, other: Any) -> Self: - return self._from_native_expr(self._native_expr.__mod__(extract_native(other))) + return self._from_native_expr(self.native.__mod__(extract_native(other))) def __invert__(self: Self) -> Self: - return self._from_native_expr(self._native_expr.__invert__()) + return self._from_native_expr(self.native.__invert__()) def cum_count(self: Self, *, reverse: bool) -> Self: if self._backend_version < (0, 20, 4): - not_null = ~self._native_expr.is_null() + not_null = ~self.native.is_null() result = not_null.cum_sum(reverse=reverse) else: - result = self._native_expr.cum_count(reverse=reverse) + result = self.native.cum_count(reverse=reverse) return self._from_native_expr(result) @@ -342,7 +334,7 @@ def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._compliant_expr._from_native_expr( - getattr(self._compliant_expr._native_expr.dt, attr)(*args, **kwargs) + getattr(self._compliant_expr.native.dt, attr)(*args, **kwargs) ) return func @@ -356,7 +348,7 @@ def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._compliant_expr._from_native_expr( - getattr(self._compliant_expr._native_expr.str, attr)(*args, **kwargs) + getattr(self._compliant_expr.native.str, attr)(*args, **kwargs) ) return func @@ -370,7 +362,7 @@ def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._compliant_expr._from_native_expr( - getattr(self._compliant_expr._native_expr.cat, attr)(*args, **kwargs) + getattr(self._compliant_expr.native.cat, attr)(*args, **kwargs) ) return func @@ -384,7 +376,7 @@ def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._compliant_expr._from_native_expr( - getattr(self._compliant_expr._native_expr.name, attr)(*args, **kwargs) + getattr(self._compliant_expr.native.name, attr)(*args, **kwargs) ) return func @@ -414,7 +406,7 @@ def __getattr__( def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( - getattr(self._expr._native_expr.list, attr)(*args, **kwargs) + getattr(self._expr.native.list, attr)(*args, **kwargs) ) return func @@ -430,7 +422,7 @@ def __getattr__( def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] return self._expr._from_native_expr( - getattr(self._expr._native_expr.struct, attr)(*args, **kwargs) + getattr(self._expr.native.struct, attr)(*args, **kwargs) ) return func From 1d3ab8e1f15e864be8a9671eeede98470e681427 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Mar 2025 14:34:20 +0000 Subject: [PATCH 2/5] refactor: `_from_native_expr` -> `_with_native` --- narwhals/_polars/expr.py | 82 ++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 4b3c711277..9ce554e68e 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -39,7 +39,7 @@ def native(self) -> pl.Expr: def __repr__(self: Self) -> str: # pragma: no cover return "PolarsExpr" - def _from_native_expr(self: Self, expr: pl.Expr) -> Self: + def _with_native(self: Self, expr: pl.Expr) -> Self: return self.__class__( expr, version=self._version, backend_version=self._backend_version ) @@ -59,14 +59,14 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._from_native_expr(getattr(self.native, attr)(*args, **kwargs)) + return self._with_native(getattr(self.native, attr)(*args, **kwargs)) return func def cast(self: Self, dtype: DType) -> Self: expr = self.native dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) - return self._from_native_expr(expr.cast(dtype_pl)) + return self._with_native(expr.cast(dtype_pl)) def ewm_mean( self: Self, @@ -97,17 +97,17 @@ def ewm_mean( **extra_kwargs, ) if self._backend_version < (1,): # pragma: no cover - return self._from_native_expr( + return self._with_native( pl.when(~expr.is_null()).then(native_expr).otherwise(None) ) - return self._from_native_expr(native_expr) + return self._with_native(native_expr) def is_nan(self: Self) -> Self: if self._backend_version < (1, 18): # pragma: no cover - return self._from_native_expr( + return self._with_native( pl.when(self.native.is_not_null()).then(self.native.is_nan()) ) - return self._from_native_expr(self.native.is_nan()) + return self._with_native(self.native.is_nan()) def over( self: Self, @@ -118,8 +118,8 @@ def over( if order_by: msg = "`order_by` in Polars requires version 1.10 or greater" raise NotImplementedError(msg) - return self._from_native_expr(self.native.over(partition_by or pl.lit(1))) - return self._from_native_expr( + return self._with_native(self.native.over(partition_by or pl.lit(1))) + return self._with_native( self.native.over(partition_by or pl.lit(1), order_by=order_by) ) @@ -140,7 +140,7 @@ def rolling_var( if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - return self._from_native_expr( + return self._with_native( self.native.rolling_var( window_size=window_size, center=center, @@ -166,7 +166,7 @@ def rolling_std( else {"min_samples": min_samples} ) - return self._from_native_expr( + return self._with_native( self.native.rolling_std( window_size=window_size, center=center, @@ -184,7 +184,7 @@ def rolling_sum( else {"min_samples": min_samples} ) - return self._from_native_expr( + return self._with_native( self.native.rolling_sum( window_size=window_size, center=center, @@ -205,7 +205,7 @@ def rolling_mean( else {"min_samples": min_samples} ) - return self._from_native_expr( + return self._with_native( self.native.rolling_mean( window_size=window_size, center=center, @@ -222,11 +222,9 @@ def map_batches( return_dtype_pl = narwhals_to_native_dtype( return_dtype, self._version, self._backend_version ) - return self._from_native_expr( - self.native.map_batches(function, return_dtype_pl) - ) + return self._with_native(self.native.map_batches(function, return_dtype_pl)) else: - return self._from_native_expr(self.native.map_batches(function)) + return self._with_native(self.native.map_batches(function)) def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None @@ -240,57 +238,57 @@ def replace_strict( if self._backend_version < (1,): msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}" raise NotImplementedError(msg) - return self._from_native_expr( + return self._with_native( expr.replace_strict(old, new, return_dtype=return_dtype_pl) ) def __eq__(self: Self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self.native.__eq__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator] def __ne__(self: Self, other: object) -> Self: # type: ignore[override] - return self._from_native_expr(self.native.__ne__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator] def __ge__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__ge__(extract_native(other))) + return self._with_native(self.native.__ge__(extract_native(other))) def __gt__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__gt__(extract_native(other))) + return self._with_native(self.native.__gt__(extract_native(other))) def __le__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__le__(extract_native(other))) + return self._with_native(self.native.__le__(extract_native(other))) def __lt__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__lt__(extract_native(other))) + return self._with_native(self.native.__lt__(extract_native(other))) def __and__(self: Self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self.native.__and__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator] def __or__(self: Self, other: PolarsExpr | bool | Any) -> Self: - return self._from_native_expr(self.native.__or__(extract_native(other))) # type: ignore[operator] + return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator] def __add__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__add__(extract_native(other))) + return self._with_native(self.native.__add__(extract_native(other))) def __sub__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__sub__(extract_native(other))) + return self._with_native(self.native.__sub__(extract_native(other))) def __mul__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__mul__(extract_native(other))) + return self._with_native(self.native.__mul__(extract_native(other))) def __pow__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__pow__(extract_native(other))) + return self._with_native(self.native.__pow__(extract_native(other))) def __truediv__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__truediv__(extract_native(other))) + return self._with_native(self.native.__truediv__(extract_native(other))) def __floordiv__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__floordiv__(extract_native(other))) + return self._with_native(self.native.__floordiv__(extract_native(other))) def __mod__(self: Self, other: Any) -> Self: - return self._from_native_expr(self.native.__mod__(extract_native(other))) + return self._with_native(self.native.__mod__(extract_native(other))) def __invert__(self: Self) -> Self: - return self._from_native_expr(self.native.__invert__()) + return self._with_native(self.native.__invert__()) def cum_count(self: Self, *, reverse: bool) -> Self: if self._backend_version < (0, 20, 4): @@ -299,7 +297,7 @@ def cum_count(self: Self, *, reverse: bool) -> Self: else: result = self.native.cum_count(reverse=reverse) - return self._from_native_expr(result) + return self._with_native(result) @property def dt(self: Self) -> PolarsExprDateTimeNamespace: @@ -333,7 +331,7 @@ def __init__(self: Self, expr: PolarsExpr) -> None: def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._compliant_expr._from_native_expr( + return self._compliant_expr._with_native( getattr(self._compliant_expr.native.dt, attr)(*args, **kwargs) ) @@ -347,7 +345,7 @@ def __init__(self: Self, expr: PolarsExpr) -> None: def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._compliant_expr._from_native_expr( + return self._compliant_expr._with_native( getattr(self._compliant_expr.native.str, attr)(*args, **kwargs) ) @@ -361,7 +359,7 @@ def __init__(self: Self, expr: PolarsExpr) -> None: def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._compliant_expr._from_native_expr( + return self._compliant_expr._with_native( getattr(self._compliant_expr.native.cat, attr)(*args, **kwargs) ) @@ -375,7 +373,7 @@ def __init__(self: Self, expr: PolarsExpr) -> None: def __getattr__(self: Self, attr: str) -> Callable[[Any], PolarsExpr]: def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._compliant_expr._from_native_expr( + return self._compliant_expr._with_native( getattr(self._compliant_expr.native.name, attr)(*args, **kwargs) ) @@ -397,7 +395,7 @@ def len(self: Self) -> PolarsExpr: elif self._expr._backend_version < (1, 17): # pragma: no cover native_result = native_result.cast(pl.UInt32()) - return self._expr._from_native_expr(native_result) + return self._expr._with_native(native_result) # TODO(FBruzzesi): Remove `pragma: no cover` once other namespace methods are added def __getattr__( @@ -405,7 +403,7 @@ def __getattr__( ) -> Callable[[Any], PolarsExpr]: # pragma: no cover def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._expr._from_native_expr( + return self._expr._with_native( getattr(self._expr.native.list, attr)(*args, **kwargs) ) @@ -421,7 +419,7 @@ def __getattr__( ) -> Callable[[Any], PolarsExpr]: # pragma: no cover def func(*args: Any, **kwargs: Any) -> PolarsExpr: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return self._expr._from_native_expr( + return self._expr._with_native( getattr(self._expr.native.struct, attr)(*args, **kwargs) ) From 61f7c55379972eac994e3b330089137550dda17f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Mar 2025 14:57:25 +0000 Subject: [PATCH 3/5] refactor: Shrinking `PolarsExpr` --- narwhals/_polars/expr.py | 154 +++++++++++++-------------------------- 1 file changed, 49 insertions(+), 105 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 9ce554e68e..9b6b528697 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -40,17 +40,11 @@ def __repr__(self: Self) -> str: # pragma: no cover return "PolarsExpr" def _with_native(self: Self, expr: pl.Expr) -> Self: - return self.__class__( - expr, version=self._version, backend_version=self._backend_version - ) + return self.__class__(expr, self._version, self._backend_version) @classmethod def _from_series(cls, series: Any) -> Self: - return cls( - series.native, - version=series._version, - backend_version=series._backend_version, - ) + return cls(series.native, series._version, series._backend_version) def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: # Let Polars do its thing. @@ -64,9 +58,8 @@ def func(*args: Any, **kwargs: Any) -> Any: return func def cast(self: Self, dtype: DType) -> Self: - expr = self.native dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) - return self._with_native(expr.cast(dtype_pl)) + return self._with_native(self.native.cast(dtype_pl)) def ewm_mean( self: Self, @@ -79,168 +72,121 @@ def ewm_mean( min_samples: int, ignore_nulls: bool, ) -> Self: - expr = self.native - - extra_kwargs = ( + kwds: dict[str, Any] = ( {"min_periods": min_samples} if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - - native_expr = expr.ewm_mean( + native = self.native.ewm_mean( com=com, span=span, half_life=half_life, alpha=alpha, adjust=adjust, ignore_nulls=ignore_nulls, - **extra_kwargs, + **kwds, ) if self._backend_version < (1,): # pragma: no cover - return self._with_native( - pl.when(~expr.is_null()).then(native_expr).otherwise(None) - ) - return self._with_native(native_expr) + native = pl.when(~self.native.is_null()).then(native).otherwise(None) + return self._with_native(native) def is_nan(self: Self) -> Self: - if self._backend_version < (1, 18): # pragma: no cover - return self._with_native( - pl.when(self.native.is_not_null()).then(self.native.is_nan()) - ) - return self._with_native(self.native.is_nan()) + if self._backend_version >= (1, 18): + native = self.native.is_nan() + else: # pragma: no cover + native = pl.when(self.native.is_not_null()).then(self.native.is_nan()) + return self._with_native(native) def over( - self: Self, - partition_by: Sequence[str], - order_by: Sequence[str] | None, + self: Self, partition_by: Sequence[str], order_by: Sequence[str] | None ) -> Self: if self._backend_version < (1, 9): if order_by: msg = "`order_by` in Polars requires version 1.10 or greater" raise NotImplementedError(msg) - return self._with_native(self.native.over(partition_by or pl.lit(1))) - return self._with_native( - self.native.over(partition_by or pl.lit(1), order_by=order_by) - ) + native = self.native.over(partition_by or pl.lit(1)) + else: + native = self.native.over(partition_by or pl.lit(1), order_by=order_by) + return self._with_native(native) def rolling_var( - self: Self, - window_size: int, - *, - min_samples: int, - center: bool, - ddof: int, + self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int ) -> Self: if self._backend_version < (1,): # pragma: no cover msg = "`rolling_var` not implemented for polars older than 1.0" raise NotImplementedError(msg) - - extra_kwargs = ( + kwds: dict[str, Any] = ( {"min_periods": min_samples} if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - return self._with_native( - self.native.rolling_var( - window_size=window_size, - center=center, - ddof=ddof, - **extra_kwargs, # type: ignore[arg-type] - ) + native = self.native.rolling_var( + window_size=window_size, center=center, ddof=ddof, **kwds ) + return self._with_native(native) def rolling_std( - self: Self, - window_size: int, - *, - min_samples: int, - center: bool, - ddof: int, + self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int ) -> Self: if self._backend_version < (1,): # pragma: no cover msg = "`rolling_std` not implemented for polars older than 1.0" raise NotImplementedError(msg) - extra_kwargs = ( + kwds: dict[str, Any] = ( {"min_periods": min_samples} if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - - return self._with_native( - self.native.rolling_std( - window_size=window_size, - center=center, - ddof=ddof, - **extra_kwargs, # type: ignore[arg-type] - ) + native = self.native.rolling_std( + window_size=window_size, center=center, ddof=ddof, **kwds ) + return self._with_native(native) def rolling_sum( self: Self, window_size: int, *, min_samples: int, center: bool ) -> Self: - extra_kwargs = ( + kwds: dict[str, Any] = ( {"min_periods": min_samples} if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - - return self._with_native( - self.native.rolling_sum( - window_size=window_size, - center=center, - **extra_kwargs, # type: ignore[arg-type] - ) - ) + native = self.native.rolling_sum(window_size=window_size, center=center, **kwds) + return self._with_native(native) def rolling_mean( - self: Self, - window_size: int, - *, - min_samples: int, - center: bool, + self: Self, window_size: int, *, min_samples: int, center: bool ) -> Self: - extra_kwargs = ( + kwds: dict[str, Any] = ( {"min_periods": min_samples} if self._backend_version < (1, 21, 0) else {"min_samples": min_samples} ) - - return self._with_native( - self.native.rolling_mean( - window_size=window_size, - center=center, - **extra_kwargs, # type: ignore[arg-type] - ) - ) + native = self.native.rolling_mean(window_size=window_size, center=center, **kwds) + return self._with_native(native) def map_batches( - self: Self, - function: Callable[..., Self], - return_dtype: DType | None, - ) -> Self: - if return_dtype is not None: - return_dtype_pl = narwhals_to_native_dtype( - return_dtype, self._version, self._backend_version - ) - return self._with_native(self.native.map_batches(function, return_dtype_pl)) - else: - return self._with_native(self.native.map_batches(function)) - - def replace_strict( - self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None + self: Self, function: Callable[..., Self], return_dtype: DType | None ) -> Self: - expr = self.native return_dtype_pl = ( narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) if return_dtype else None ) + native = self.native.map_batches(function, return_dtype_pl) + return self._with_native(native) + + def replace_strict( + self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None + ) -> Self: if self._backend_version < (1,): msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}" raise NotImplementedError(msg) - return self._with_native( - expr.replace_strict(old, new, return_dtype=return_dtype_pl) + return_dtype_pl = ( + narwhals_to_native_dtype(return_dtype, self._version, self._backend_version) + if return_dtype + else None ) + native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl) + return self._with_native(native) def __eq__(self: Self, other: object) -> Self: # type: ignore[override] return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator] @@ -292,11 +238,9 @@ def __invert__(self: Self) -> Self: def cum_count(self: Self, *, reverse: bool) -> Self: if self._backend_version < (0, 20, 4): - not_null = ~self.native.is_null() - result = not_null.cum_sum(reverse=reverse) + result = (~self.native.is_null()).cum_sum(reverse=reverse) else: result = self.native.cum_count(reverse=reverse) - return self._with_native(result) @property From ebf63b0502bddb854d7d5c9339000e629379b763 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Mar 2025 15:08:10 +0000 Subject: [PATCH 4/5] refactor: Add `PolarsExpr._renamed_min_periods` Shorter, maintained in one place --- narwhals/_polars/expr.py | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 9b6b528697..8cbd71cd89 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -57,6 +57,10 @@ def func(*args: Any, **kwargs: Any) -> Any: return func + def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]: + name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples" + return {name: min_samples} + def cast(self: Self, dtype: DType) -> Self: dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) return self._with_native(self.native.cast(dtype_pl)) @@ -72,11 +76,6 @@ def ewm_mean( min_samples: int, ignore_nulls: bool, ) -> Self: - kwds: dict[str, Any] = ( - {"min_periods": min_samples} - if self._backend_version < (1, 21, 0) - else {"min_samples": min_samples} - ) native = self.native.ewm_mean( com=com, span=span, @@ -84,7 +83,7 @@ def ewm_mean( alpha=alpha, adjust=adjust, ignore_nulls=ignore_nulls, - **kwds, + **self._renamed_min_periods(min_samples), ) if self._backend_version < (1,): # pragma: no cover native = pl.when(~self.native.is_null()).then(native).otherwise(None) @@ -115,11 +114,7 @@ def rolling_var( if self._backend_version < (1,): # pragma: no cover msg = "`rolling_var` not implemented for polars older than 1.0" raise NotImplementedError(msg) - kwds: dict[str, Any] = ( - {"min_periods": min_samples} - if self._backend_version < (1, 21, 0) - else {"min_samples": min_samples} - ) + kwds = self._renamed_min_periods(min_samples) native = self.native.rolling_var( window_size=window_size, center=center, ddof=ddof, **kwds ) @@ -131,11 +126,7 @@ def rolling_std( if self._backend_version < (1,): # pragma: no cover msg = "`rolling_std` not implemented for polars older than 1.0" raise NotImplementedError(msg) - kwds: dict[str, Any] = ( - {"min_periods": min_samples} - if self._backend_version < (1, 21, 0) - else {"min_samples": min_samples} - ) + kwds = self._renamed_min_periods(min_samples) native = self.native.rolling_std( window_size=window_size, center=center, ddof=ddof, **kwds ) @@ -144,22 +135,14 @@ def rolling_std( def rolling_sum( self: Self, window_size: int, *, min_samples: int, center: bool ) -> Self: - kwds: dict[str, Any] = ( - {"min_periods": min_samples} - if self._backend_version < (1, 21, 0) - else {"min_samples": min_samples} - ) + kwds = self._renamed_min_periods(min_samples) native = self.native.rolling_sum(window_size=window_size, center=center, **kwds) return self._with_native(native) def rolling_mean( self: Self, window_size: int, *, min_samples: int, center: bool ) -> Self: - kwds: dict[str, Any] = ( - {"min_periods": min_samples} - if self._backend_version < (1, 21, 0) - else {"min_samples": min_samples} - ) + kwds = self._renamed_min_periods(min_samples) native = self.native.rolling_mean(window_size=window_size, center=center, **kwds) return self._with_native(native) From 712a388b111377ef403e4bb9232f0e32529c0fa4 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 29 Mar 2025 15:11:56 +0000 Subject: [PATCH 5/5] fix(typing): Avoid redef --- narwhals/_polars/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 8cbd71cd89..af97d6b89a 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -316,7 +316,7 @@ def len(self: Self) -> PolarsExpr: native_result = native_expr.list.len() if self._expr._backend_version < (1, 16): # pragma: no cover - native_result: pl.Expr = ( # type: ignore[no-redef] + native_result = ( pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32()) ) elif self._expr._backend_version < (1, 17): # pragma: no cover