Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/downstream_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ jobs:
- name: Run `make narwhals-test-integration`
run: |
cd py-shiny
# Isort seems to behave slightly differently in CI
# so we ignore its output
make format -s
make narwhals-test-integration

tea-tasting:
Expand Down
58 changes: 21 additions & 37 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(
self._to_compliant_expr = to_compliant_expr
self._metadata = metadata

def _from_callable(self, to_compliant_expr: Callable[[Any], Any]) -> Self:
# Instantiate new Expr keeping metadata unchanged.
return self.__class__(to_compliant_expr, self._metadata)

def __repr__(self: Self) -> str:
return f"Narwhals Expr\nmetadata: {self._metadata}\n"

Expand Down Expand Up @@ -86,9 +90,7 @@ def alias(self: Self, name: str) -> Self:
| 1 15 |
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).alias(name), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).alias(name))

def pipe(
self: Self,
Expand Down Expand Up @@ -149,9 +151,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self:
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
_validate_dtype(dtype)
return self.__class__(
lambda plx: self._to_compliant_expr(plx).cast(dtype), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).cast(dtype))

# --- binary ---
def __eq__(self: Self, other: Self | Any) -> Self: # type: ignore[override]
Expand Down Expand Up @@ -348,9 +348,7 @@ def __rmod__(self: Self, other: Any) -> Self:

# --- unary ---
def __invert__(self: Self) -> Self:
return self.__class__(
lambda plx: self._to_compliant_expr(plx).__invert__(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).__invert__())

def any(self: Self) -> Self:
"""Return whether any of the values in the column are `True`.
Expand Down Expand Up @@ -490,7 +488,7 @@ def ewm_mean(
β”‚ 2.428571 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
return self._from_callable(
lambda plx: self._to_compliant_expr(plx).ewm_mean(
com=com,
span=span,
Expand All @@ -499,8 +497,7 @@ def ewm_mean(
adjust=adjust,
min_samples=min_samples,
ignore_nulls=ignore_nulls,
),
self._metadata,
)
)

def mean(self: Self) -> Self:
Expand Down Expand Up @@ -897,9 +894,7 @@ def abs(self: Self) -> Self:
|1 -2 4 2 4|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).abs(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).abs())

def cum_sum(self: Self, *, reverse: bool = False) -> Self:
"""Return cumulative sum.
Expand Down Expand Up @@ -1074,11 +1069,10 @@ def replace_strict(
new = list(old.values())
old = list(old.keys())

return self.__class__(
return self._from_callable(
lambda plx: self._to_compliant_expr(plx).replace_strict(
old, new, return_dtype=return_dtype
),
self._metadata,
)
)

def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
Expand Down Expand Up @@ -1196,11 +1190,10 @@ def is_in(self: Self, other: Any) -> Self:
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
if isinstance(other, Iterable) and not isinstance(other, (str, bytes)):
return self.__class__(
return self._from_callable(
lambda plx: self._to_compliant_expr(plx).is_in(
to_native(other, pass_through=True)
),
self._metadata,
)
else:
msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead."
Expand Down Expand Up @@ -1284,9 +1277,7 @@ def is_null(self: Self) -> Self:
|β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_null(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_null())

def is_nan(self: Self) -> Self:
"""Indicate which values are NaN.
Expand Down Expand Up @@ -1321,9 +1312,7 @@ def is_nan(self: Self) -> Self:
|β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_nan(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_nan())

def arg_true(self: Self) -> Self:
"""Find elements where boolean expression is True.
Expand Down Expand Up @@ -1424,11 +1413,10 @@ def fill_null(
if strategy is not None and strategy not in {"forward", "backward"}:
msg = f"strategy not supported: {strategy}"
raise ValueError(msg)
return self.__class__(
return self._from_callable(
lambda plx: self._to_compliant_expr(plx).fill_null(
value=value, strategy=strategy, limit=limit
),
self._metadata,
)
)

# --- partial reduction ---
Expand Down Expand Up @@ -1607,9 +1595,7 @@ def is_unique(self: Self) -> Self:
|3 1 c False True|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_unique(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_unique())

def null_count(self: Self) -> Self:
r"""Count null values.
Expand Down Expand Up @@ -1830,8 +1816,8 @@ def round(self: Self, decimals: int = 0) -> Self:
|2 3.901234 3.9|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).round(decimals), self._metadata
return self._from_callable(
lambda plx: self._to_compliant_expr(plx).round(decimals)
)

def len(self: Self) -> Self:
Expand Down Expand Up @@ -2003,9 +1989,7 @@ def is_finite(self: Self) -> Self:
|β””β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜|
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
return self.__class__(
lambda plx: self._to_compliant_expr(plx).is_finite(), self._metadata
)
return self._from_callable(lambda plx: self._to_compliant_expr(plx).is_finite())

def cum_count(self: Self, *, reverse: bool = False) -> Self:
r"""Return the cumulative count of the non-null values in the column.
Expand Down
Loading