Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
kwargs: dict[str, Any] | None = None,
) -> None:
self._call = call
self._depth = depth
Expand All @@ -53,7 +53,7 @@ def __init__(
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
self._kwargs = kwargs
self._kwargs = kwargs or {}

def __repr__(self: Self) -> str: # pragma: no cover
return f"ArrowExpr(depth={self._depth}, function_name={self._function_name}, "
Expand Down Expand Up @@ -117,7 +117,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
)

@classmethod
Expand Down Expand Up @@ -148,7 +147,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
)

def __narwhals_namespace__(self: Self) -> ArrowNamespace:
Expand Down Expand Up @@ -315,7 +313,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
alias_output_names=alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "name": name},
kwargs=self._kwargs,
)

def null_count(self: Self) -> Self:
Expand Down
17 changes: 3 additions & 14 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _create_expr_from_callable(
function_name: str,
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
kwargs: dict[str, Any],
kwargs: dict[str, Any] | None = None,
) -> ArrowExpr:
from narwhals._arrow.expr import ArrowExpr

Expand All @@ -75,7 +75,6 @@ def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr:
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def _create_series_from_scalar(
Expand Down Expand Up @@ -142,7 +141,6 @@ def len(self: Self) -> ArrowExpr:
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def all(self: Self) -> ArrowExpr:
Expand All @@ -165,7 +163,6 @@ def all(self: Self) -> ArrowExpr:
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
Expand All @@ -188,7 +185,6 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
Expand All @@ -202,7 +198,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
kwargs={"exprs": exprs},
)

def any_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
Expand Down Expand Up @@ -387,11 +382,6 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
function_name="concat_str",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
kwargs={
"exprs": exprs,
"separator": separator,
"ignore_nulls": ignore_nulls,
},
)


Expand Down Expand Up @@ -466,7 +456,6 @@ def then(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen:
alias_output_names=getattr(value, "_alias_output_names", None),
backend_version=self._backend_version,
version=self._version,
kwargs={"value": value},
)


Expand All @@ -481,7 +470,7 @@ def __init__(
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
kwargs: dict[str, Any] | None = None,
) -> None:
self._backend_version = backend_version
self._version = version
Expand All @@ -490,7 +479,7 @@ def __init__(
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._kwargs = kwargs
self._kwargs = kwargs or {}

def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr:
# type ignore because we are setting the `_call` attribute to a
Expand Down
23 changes: 7 additions & 16 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return selector(self, func, evaluate_output_names, {"dtypes": dtypes})
return selector(self, func, evaluate_output_names)

def matches(self: Self, pattern: str) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand All @@ -45,7 +45,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if re.search(pattern, col)]

return selector(self, func, evaluate_output_names, {"pattern": pattern})
return selector(self, func, evaluate_output_names)

def numeric(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
Expand Down Expand Up @@ -82,7 +82,7 @@ def all(self: Self) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [df[col] for col in df.columns]

return selector(self, func, lambda df: df.columns, {})
return selector(self, func, lambda df: df.columns)

def datetime(
self: Self,
Expand Down Expand Up @@ -118,7 +118,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
)
]

return selector(self, func, evaluate_output_names, {})
return selector(self, func, evaluate_output_names)


class ArrowSelector(ArrowExpr):
Expand All @@ -134,7 +134,6 @@ def _to_expr(self: Self) -> ArrowExpr:
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs=self._kwargs,
)

def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any:
Expand All @@ -151,9 +150,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() - other

Expand All @@ -175,9 +172,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() | other

Expand All @@ -195,9 +190,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)

else:
return self._to_expr() & other
Expand All @@ -210,7 +203,6 @@ def selector(
context: _LimitedContext,
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
kwargs: dict[str, Any],
/,
) -> ArrowSelector:
return ArrowSelector(
Expand All @@ -221,5 +213,4 @@ def selector(
alias_output_names=None,
backend_version=context._backend_version,
version=context._version,
kwargs=kwargs,
)
18 changes: 9 additions & 9 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,29 +431,29 @@ def over(self: Self, keys: list[str]) -> Self:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

reverse = self._kwargs.get("reverse", False)
if reverse:
msg = (
"Cumulative operation with `reverse=True` is not supported in "
"over context for pandas-like backend."
)
raise NotImplementedError(msg)

unsupported_reverse_msg = (
"Cumulative operation with `reverse=True` is not supported in "
"over context for pandas-like backend."
)
if function_name == "cum_count":
if self._kwargs["reverse"]:
raise NotImplementedError(unsupported_reverse_msg)
plx = self.__narwhals_namespace__()
df = df.with_columns(~plx.col(*output_names).is_null())

if function_name == "shift":
kwargs = {"periods": self._kwargs["n"]}
elif function_name == "rank":
_method = self._kwargs.get("method", "average")
_method = self._kwargs["method"]
kwargs = {
"method": "first" if _method == "ordinal" else _method,
"ascending": not self._kwargs["descending"],
"na_option": "keep",
"pct": False,
}
else: # Cumulative operation
if self._kwargs["reverse"]:
raise NotImplementedError(unsupported_reverse_msg)
kwargs = {"skipna": True}

res_native = getattr(
Expand Down
23 changes: 8 additions & 15 deletions narwhals/_pandas_like/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return selector(self, func, evaluate_output_names, {"dtypes": dtypes})
return selector(self, func, evaluate_output_names)

def matches(self: Self, pattern: str) -> PandasSelector:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
Expand All @@ -46,7 +46,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
return [col for col in df.columns if re.search(pattern, col)]

return selector(self, func, evaluate_output_names, {"pattern": pattern})
return selector(self, func, evaluate_output_names)

def numeric(self: Self) -> PandasSelector:
dtypes = import_dtypes_module(self._version)
Expand Down Expand Up @@ -83,7 +83,7 @@ def all(self: Self) -> PandasSelector:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
return [df[col] for col in df.columns]

return selector(self, func, lambda df: df.columns, {})
return selector(self, func, lambda df: df.columns)
Copy link
Member

@dangotbanned dangotbanned Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was planning to follow up by replacing these lambdas:

diff --git a/narwhals/utils.py b/narwhals/utils.py
index cb33a603..05b9cc60 100644
--- a/narwhals/utils.py
+++ b/narwhals/utils.py
@@ -10,6 +10,7 @@ from secrets import token_hex
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Iterable
+from typing import Iterator
 from typing import Sequence
 from typing import TypeVar
 from typing import Union
@@ -59,6 +60,7 @@ if TYPE_CHECKING:
     from narwhals.typing import DataFrameLike
     from narwhals.typing import DTypes
     from narwhals.typing import IntoSeriesT
+    from narwhals.typing import NativeFrame
     from narwhals.typing import SizeUnit
     from narwhals.typing import SupportsNativeNamespace
     from narwhals.typing import TimeUnit
@@ -1303,6 +1305,14 @@ def dtype_matches_time_unit_and_time_zone(
     )
 
 
+def get_columns(df: NativeFrame) -> Sequence[str]:
+    return df.columns
+
+
+def iter_columns(df: NativeFrame) -> Iterator[str]:
+    yield from df.columns
+
+
 def _hasattr_static(obj: Any, attr: str) -> bool:
     sentinel = object()
     return getattr_static(obj, attr, sentinel) is not sentinel

`iter_columns` wouldn't be for this part, but would be usable in lots of places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, thanks!


def datetime(
self: Self,
Expand Down Expand Up @@ -119,7 +119,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]:
)
]

return selector(self, func, evaluate_output_names, {})
return selector(self, func, evaluate_output_names)


class PandasSelector(PandasLikeExpr):
Expand Down Expand Up @@ -155,9 +155,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() - other

Expand All @@ -179,9 +177,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() | other

Expand All @@ -199,9 +195,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() & other

Expand All @@ -213,7 +207,6 @@ def selector(
context: _FullContext,
call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]],
evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]],
kwargs: dict[str, Any],
/,
) -> PandasSelector:
return PandasSelector(
Expand All @@ -225,5 +218,5 @@ def selector(
implementation=context._implementation,
backend_version=context._backend_version,
version=context._version,
kwargs=kwargs,
kwargs={},
)
Loading