diff --git a/.github/workflows/downstream_tests.yml b/.github/workflows/downstream_tests.yml index bed90f5b6d..3b912cf0ee 100644 --- a/.github/workflows/downstream_tests.yml +++ b/.github/workflows/downstream_tests.yml @@ -196,6 +196,9 @@ jobs: - name: Run `make narwhals-test-integration` run: | cd py-shiny + # Isort seems to behave slightly differently in CI + # so we ignore its output + make format -s make narwhals-test-integration tea-tasting: diff --git a/narwhals/expr.py b/narwhals/expr.py index a86eee2bc3..84a4b93b6b 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -51,6 +51,10 @@ def __init__( self._to_compliant_expr = to_compliant_expr self._metadata = metadata + def _from_callable(self, to_compliant_expr: Callable[[Any], Any]) -> Self: + # Instantiate new Expr keeping metadata unchanged. + return self.__class__(to_compliant_expr, self._metadata) + def __repr__(self: Self) -> str: return f"Narwhals Expr\nmetadata: {self._metadata}\n" @@ -86,9 +90,7 @@ def alias(self: Self, name: str) -> Self: | 1 15 | └──────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).alias(name), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).alias(name)) def pipe( self: Self, @@ -149,9 +151,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: └──────────────────┘ """ _validate_dtype(dtype) - return self.__class__( - lambda plx: self._to_compliant_expr(plx).cast(dtype), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).cast(dtype)) # --- binary --- def __eq__(self: Self, other: Self | Any) -> Self: # type: ignore[override] @@ -348,9 +348,7 @@ def __rmod__(self: Self, other: Any) -> Self: # --- unary --- def __invert__(self: Self) -> Self: - return self.__class__( - lambda plx: self._to_compliant_expr(plx).__invert__(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).__invert__()) def any(self: Self) -> Self: """Return whether any of the values in the column are `True`. @@ -490,7 +488,7 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ - return self.__class__( + return self._from_callable( lambda plx: self._to_compliant_expr(plx).ewm_mean( com=com, span=span, @@ -499,8 +497,7 @@ def ewm_mean( adjust=adjust, min_samples=min_samples, ignore_nulls=ignore_nulls, - ), - self._metadata, + ) ) def mean(self: Self) -> Self: @@ -897,9 +894,7 @@ def abs(self: Self) -> Self: |1 -2 4 2 4| └─────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).abs()) def cum_sum(self: Self, *, reverse: bool = False) -> Self: """Return cumulative sum. @@ -1074,11 +1069,10 @@ def replace_strict( new = list(old.values()) old = list(old.keys()) - return self.__class__( + return self._from_callable( lambda plx: self._to_compliant_expr(plx).replace_strict( old, new, return_dtype=return_dtype - ), - self._metadata, + ) ) def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -1196,11 +1190,10 @@ def is_in(self: Self, other: Any) -> Self: └──────────────────┘ """ if isinstance(other, Iterable) and not isinstance(other, (str, bytes)): - return self.__class__( + return self._from_callable( lambda plx: self._to_compliant_expr(plx).is_in( to_native(other, pass_through=True) ), - self._metadata, ) else: msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -1284,9 +1277,7 @@ def is_null(self: Self) -> Self: |└───────┴────────┴───────────┴───────────┘| └──────────────────────────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_null(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_null()) def is_nan(self: Self) -> Self: """Indicate which values are NaN. @@ -1321,9 +1312,7 @@ def is_nan(self: Self) -> Self: |└───────┴────────┴──────────┴──────────┘| └────────────────────────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_nan(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_nan()) def arg_true(self: Self) -> Self: """Find elements where boolean expression is True. @@ -1424,11 +1413,10 @@ def fill_null( if strategy is not None and strategy not in {"forward", "backward"}: msg = f"strategy not supported: {strategy}" raise ValueError(msg) - return self.__class__( + return self._from_callable( lambda plx: self._to_compliant_expr(plx).fill_null( value=value, strategy=strategy, limit=limit - ), - self._metadata, + ) ) # --- partial reduction --- @@ -1607,9 +1595,7 @@ def is_unique(self: Self) -> Self: |3 1 c False True| └─────────────────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_unique(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_unique()) def null_count(self: Self) -> Self: r"""Count null values. @@ -1830,8 +1816,8 @@ def round(self: Self, decimals: int = 0) -> Self: |2 3.901234 3.9| └──────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).round(decimals), self._metadata + return self._from_callable( + lambda plx: self._to_compliant_expr(plx).round(decimals) ) def len(self: Self) -> Self: @@ -2003,9 +1989,7 @@ def is_finite(self: Self) -> Self: |└──────┴─────────────┘| └──────────────────────┘ """ - return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_finite(), self._metadata - ) + return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_finite()) def cum_count(self: Self, *, reverse: bool = False) -> Self: r"""Return the cumulative count of the non-null values in the column.