diff --git a/.github/workflows/downstream_tests.yml b/.github/workflows/downstream_tests.yml index 2d47d8d740..c808e00487 100644 --- a/.github/workflows/downstream_tests.yml +++ b/.github/workflows/downstream_tests.yml @@ -489,7 +489,7 @@ jobs: run: | cd validoopsie # empty pytest.ini to avoid pytest using narwhals configs - touch pytest.ini + touch pytest.ini touch tests/__init__.py touch tests/utils/__init__.py uv run pytest tests diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 59ffe7d189..48e2715f0d 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -381,7 +381,9 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: return self._with_native(native_frame, validate_column_names=False) - def group_by(self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool + ) -> ArrowGroupBy: from narwhals._arrow.group_by import ArrowGroupBy return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys) diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index ed83589c41..a9aa38bd4b 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -176,7 +176,7 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]: ) raise NotImplementedError(msg) - tmp = df.group_by(*partition_by, drop_null_keys=False).agg(self) + tmp = df.group_by(partition_by, drop_null_keys=False).agg(self) tmp = df.simple_select(*partition_by).join( tmp, how="left", diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 63741da558..6a13385a34 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -40,28 +40,28 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]): def __init__( self, - compliant_frame: ArrowDataFrame, - keys: Sequence[str], + df: ArrowDataFrame, + keys: Sequence[ArrowExpr] | Sequence[str], /, *, drop_null_keys: bool, ) -> None: - if drop_null_keys: - self._compliant_frame = compliant_frame.drop_nulls(keys) - else: - self._compliant_frame = compliant_frame - self._keys: list[str] = list(keys) + self._df = df + frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys) + self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame self._grouped = pa.TableGroupBy(self.compliant.native, self._keys) + self._drop_null_keys = drop_null_keys def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: self._ensure_all_simple(exprs) aggs: list[tuple[str, str, Any]] = [] expected_pyarrow_column_names: list[str] = self._keys.copy() new_column_names: list[str] = self._keys.copy() + exclude = (*self._keys, *self._output_key_names) for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self.compliant, self._keys + expr, self.compliant, exclude ) if expr._depth == 0: @@ -120,7 +120,10 @@ def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame: result_simple = result_simple.select( [*self._keys, *[col for col in columns if col not in self._keys]] ) - return self.compliant._with_native(result_simple) + + return self.compliant._with_native(result_simple).rename( + dict(zip(self._keys, self._output_key_names)) + ) def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: col_token = generate_temporary_column_name( @@ -142,9 +145,13 @@ def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: null_replacement=null_token, ) table = table.add_column(i=0, field_=col_token, column=key_values) + for v in pc.unique(key_values): t = self.compliant._with_native( table.filter(pc.equal(table[col_token], v)).drop([col_token]) ) row = t.simple_select(*self._keys).row(0) - yield tuple(extract_py_scalar(el) for el in row), t + yield ( + tuple(extract_py_scalar(el) for el in row), + t.simple_select(*self._df.columns), + ) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index f218db68c9..b785eba71d 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -18,7 +18,6 @@ from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import NativeFrameT from narwhals._compliant.typing import NativeSeriesT -from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._translate import ArrowConvertible from narwhals._translate import DictConvertible from narwhals._translate import FromNative @@ -159,7 +158,10 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... def gather_every(self, n: int, offset: int) -> Self: ... def get_column(self, name: str) -> CompliantSeriesT: ... def group_by( - self, *keys: str, drop_null_keys: bool + self, + keys: Sequence[str] | Sequence[CompliantExprT_contra], + *, + drop_null_keys: bool, ) -> DataFrameGroupBy[Self, Any]: ... def head(self, n: int) -> Self: ... def item(self, row: int | None, column: int | str | None) -> Any: ... @@ -250,6 +252,10 @@ def write_csv(self, file: str | Path | BytesIO) -> None: ... def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ... def write_parquet(self, file: str | Path | BytesIO) -> None: ... + def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]: + it = (expr._evaluate_aliases(self) for expr in exprs) + return list(chain.from_iterable(it)) + class CompliantLazyFrame( _StoresNative[NativeFrameT], @@ -302,8 +308,11 @@ def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ... ) def gather_every(self, n: int, offset: int) -> Self: ... def group_by( - self, *keys: str, drop_null_keys: bool - ) -> CompliantGroupBy[Self, Any]: ... + self, + keys: Sequence[str] | Sequence[CompliantExprT_contra], + *, + drop_null_keys: bool, + ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ... def head(self, n: int) -> Self: ... def join( self, @@ -349,6 +358,10 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: assert len(result) == 1 # debug assertion # noqa: S101 return result[0] + def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]: + it = (expr._evaluate_aliases(self) for expr in exprs) + return list(chain.from_iterable(it)) + class EagerDataFrame( CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT], @@ -379,7 +392,7 @@ def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want. """ - _, aliases = evaluate_output_names_and_aliases(expr, self, []) + aliases = expr._evaluate_aliases(self) result = expr(self) if list(aliases) != ( result_aliases := [s.name for s in result] diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 734c298aa6..183060c8ad 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -30,7 +30,6 @@ from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import LazyExprT from narwhals._compliant.typing import NativeExprT -from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals.dependencies import get_numpy from narwhals.dependencies import is_numpy_array from narwhals.dtypes import DType @@ -195,19 +194,6 @@ def clip( upper_bound: Self | NumericLiteral | TemporalLiteral | None, ) -> Self: ... - @property - def str(self) -> Any: ... - @property - def name(self) -> Any: ... - @property - def dt(self) -> Any: ... - @property - def cat(self) -> Any: ... - @property - def list(self) -> Any: ... - @property - def struct(self) -> Any: ... - def ewm_mean( self, *, @@ -287,6 +273,25 @@ def _is_multi_output_unnamed(self) -> bool: assert self._metadata is not None # noqa: S101 return self._metadata.expansion_kind.is_multi_unnamed() + def _evaluate_aliases( + self: CompliantExpr[CompliantFrameT, Any], frame: CompliantFrameT, / + ) -> Sequence[str]: + names = self._evaluate_output_names(frame) + return alias(names) if (alias := self._alias_output_names) else names + + @property + def str(self) -> Any: ... + @property + def name(self) -> Any: ... + @property + def dt(self) -> Any: ... + @property + def cat(self) -> Any: ... + @property + def list(self) -> Any: ... + @property + def struct(self) -> Any: ... + class DepthTrackingExpr( CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co], @@ -467,7 +472,7 @@ def _reuse_series_inner( series._from_scalar(method(series)) if returns_scalar else method(series) for series in self(df) ] - _, aliases = evaluate_output_names_and_aliases(self, df, []) + aliases = self._evaluate_aliases(df) if [s.name for s in out] != list(aliases): # pragma: no cover msg = ( f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n" diff --git a/narwhals/_compliant/group_by.py b/narwhals/_compliant/group_by.py index 37d37fb259..0c36c8924e 100644 --- a/narwhals/_compliant/group_by.py +++ b/narwhals/_compliant/group_by.py @@ -13,19 +13,27 @@ from typing import Sequence from typing import TypeVar +from narwhals._compliant.typing import CompliantDataFrameAny +from narwhals._compliant.typing import CompliantDataFrameT from narwhals._compliant.typing import CompliantDataFrameT_co from narwhals._compliant.typing import CompliantExprT_contra +from narwhals._compliant.typing import CompliantFrameT from narwhals._compliant.typing import CompliantFrameT_co -from narwhals._compliant.typing import CompliantLazyFrameT_co +from narwhals._compliant.typing import CompliantLazyFrameAny +from narwhals._compliant.typing import CompliantLazyFrameT from narwhals._compliant.typing import DepthTrackingExprAny from narwhals._compliant.typing import DepthTrackingExprT_contra from narwhals._compliant.typing import EagerExprT_contra from narwhals._compliant.typing import LazyExprT_contra from narwhals._compliant.typing import NativeExprT_co +from narwhals._expression_parsing import is_multi_output +from narwhals.utils import is_sequence_of if TYPE_CHECKING: from typing_extensions import TypeAlias + _SameFrameT = TypeVar("_SameFrameT", CompliantDataFrameAny, CompliantLazyFrameAny) + if not TYPE_CHECKING: # pragma: no cover if sys.version_info >= (3, 9): @@ -58,7 +66,6 @@ class CompliantGroupBy(Protocol38[CompliantFrameT_co, CompliantExprT_contra]): _compliant_frame: Any - _keys: Sequence[str] @property def compliant(self) -> CompliantFrameT_co: @@ -67,7 +74,7 @@ def compliant(self) -> CompliantFrameT_co: def __init__( self, compliant_frame: CompliantFrameT_co, - keys: Sequence[str], + keys: Sequence[CompliantExprT_contra] | Sequence[str], /, *, drop_null_keys: bool, @@ -83,9 +90,60 @@ class DataFrameGroupBy( def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ... +class ParseKeysGroupBy( + CompliantGroupBy[CompliantFrameT, CompliantExprT_contra], + Protocol38[CompliantFrameT, CompliantExprT_contra], +): + def _parse_keys( + self, + compliant_frame: CompliantFrameT, + keys: Sequence[CompliantExprT_contra] | Sequence[str], + ) -> tuple[CompliantFrameT, list[str], list[str]]: + if is_sequence_of(keys, str): + keys_str = list(keys) + return compliant_frame, keys_str, keys_str.copy() + else: + return self._parse_expr_keys(compliant_frame, keys=keys) + + @staticmethod + def _parse_expr_keys( + compliant_frame: _SameFrameT, keys: Sequence[CompliantExprT_contra] + ) -> tuple[_SameFrameT, list[str], list[str]]: + """Parses key expressions to set up `.agg` operation with correct information. + + Since keys are expressions, it's possible to alias any such key to match + other dataframe column names. + + In order to match polars behavior and not overwrite columns when evaluating keys: + + - We evaluate what the output key names should be, in order to remap temporary column + names to the expected ones, and to exclude those from unnamed expressions in + `.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520) + - Create temporary names for evaluated key expressions that are guaranteed to have + no overlap with any existing column name. + - Add these temporary columns to the compliant dataframe. + """ + suffix_token = "_" * (max(len(str(c)) for c in compliant_frame.columns) + 1) + output_names = compliant_frame._evaluate_aliases(*keys) + + safe_keys = [ + # multi-output expression cannot have duplicate names, hence it's safe to suffix + key.name.suffix(suffix_token) + if key._metadata is not None and is_multi_output(key._metadata.expansion_kind) + # otherwise it's single named and we can use Expr.alias + else key.alias(f"{new_name}{suffix_token}") + for key, new_name in zip(keys, output_names) + ] + return ( + compliant_frame.with_columns(*safe_keys), + compliant_frame._evaluate_aliases(*safe_keys), + output_names, + ) + + class DepthTrackingGroupBy( - CompliantGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra], - Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co], + ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra], + Protocol38[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co], ): """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`.""" @@ -138,16 +196,20 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any: class EagerGroupBy( - DepthTrackingGroupBy[CompliantDataFrameT_co, EagerExprT_contra, str], - DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra], - Protocol38[CompliantDataFrameT_co, EagerExprT_contra], + DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, str], + DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra], + Protocol38[CompliantDataFrameT, EagerExprT_contra], ): ... class LazyGroupBy( - CompliantGroupBy[CompliantLazyFrameT_co, LazyExprT_contra], - Protocol38[CompliantLazyFrameT_co, LazyExprT_contra, NativeExprT_co], + ParseKeysGroupBy[CompliantLazyFrameT, LazyExprT_contra], + CompliantGroupBy[CompliantLazyFrameT, LazyExprT_contra], + Protocol38[CompliantLazyFrameT, LazyExprT_contra, NativeExprT_co], ): + _keys: list[str] + _output_key_names: list[str] + def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: output_names = expr._evaluate_output_names(self.compliant) aliases = ( @@ -157,8 +219,9 @@ def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]: ) native_exprs = expr(self.compliant) if expr._is_multi_output_unnamed(): + exclude = {*self._keys, *self._output_key_names} for native_expr, name, alias in zip(native_exprs, output_names, aliases): - if name not in self._keys: + if name not in exclude: yield native_expr.alias(alias) else: for native_expr, alias in zip(native_exprs, aliases): diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 2f1cc7842e..e32f4656c2 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -307,7 +307,7 @@ def _with_native( """Return a new `CompliantSeries`, wrapping the native `series`. In cases when operations are known to not affect whether a result should - be broadcast, we can pass `preverse_broadcast=True`. + be broadcast, we can pass `preserve_broadcast=True`. Set this with care - it should only be set for unary expressions which don't change length or order, such as `.alias` or `.fill_null`. If in doubt, don't set it, you probably don't need it. diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index b78bf8d3c7..44e27dff3b 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -406,10 +406,12 @@ def join_asof( ), ) - def group_by(self, *by: str, drop_null_keys: bool) -> DaskLazyGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool + ) -> DaskLazyGroupBy: from narwhals._dask.group_by import DaskLazyGroupBy - return DaskLazyGroupBy(self, by, drop_null_keys=drop_null_keys) + return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) def tail(self, n: int) -> Self: # pragma: no cover native_frame = self.native diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 195d2414f0..2a23547da7 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -70,12 +70,18 @@ class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregat } def __init__( - self, df: DaskLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool + self, + df: DaskLazyFrame, + keys: Sequence[DaskExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - self._compliant_frame = df - self._keys: list[str] = list(keys) + self._compliant_frame, self._keys, self._output_key_names = self._parse_keys( + df, keys=keys + ) self._grouped = self.compliant.native.groupby( - list(self._keys), dropna=drop_null_keys, observed=True + self._keys, dropna=drop_null_keys, observed=True ) def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: @@ -83,17 +89,21 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: if not exprs: # No aggregation provided - return self.compliant.simple_select(*self._keys).unique( - self._keys, keep="any" + return ( + self.compliant.simple_select(*self._keys) + .unique(self._keys, keep="any") + .rename(dict(zip(self._keys, self._output_key_names))) ) + self._ensure_all_simple(exprs) # This should be the fastpath, but cuDF is too far behind to use it. # - https://github.com/rapidsai/cudf/issues/15118 # - https://github.com/rapidsai/cudf/issues/15084 simple_aggregations: dict[str, tuple[str, Aggregation]] = {} + exclude = (*self._keys, *self._output_key_names) for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self.compliant, self._keys + expr, self.compliant, exclude ) if expr._depth == 0: # e.g. `agg(nw.len())` @@ -114,4 +124,4 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame: self._grouped.agg(**simple_aggregations).reset_index(), backend_version=self.compliant._backend_version, version=self.compliant._version, - ) + ).rename(dict(zip(self._keys, self._output_key_names))) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index fe90b0353c..475154a106 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -4,7 +4,6 @@ from typing import Any from typing import Sequence -from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.utils import select_columns_by_name from narwhals.dependencies import get_pandas from narwhals.dependencies import get_pyarrow @@ -40,8 +39,8 @@ def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]: native_results: list[tuple[str, dx.Series]] = [] for expr in exprs: - native_series_list = expr._call(df) - _, aliases = evaluate_output_names_and_aliases(expr, df, []) + native_series_list = expr(df) + aliases = expr._evaluate_aliases(df) if len(aliases) != len(native_series_list): # pragma: no cover msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" raise AssertionError(msg) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 6e306bb12b..1d2cff383a 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -248,7 +248,9 @@ def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self: df, backend_version=self._backend_version, version=self._version ) - def group_by(self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool + ) -> DuckDBGroupBy: from narwhals._duckdb.group_by import DuckDBGroupBy return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 1a3479d3fb..981d7a7389 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -345,13 +345,7 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: raise ValueError(msg) return [name] - return self.__class__( - self._call, - evaluate_output_names=self._evaluate_output_names, - alias_output_names=alias_output_names, - backend_version=self._backend_version, - version=self._version, - ) + return self._with_alias_output_names(alias_output_names) def abs(self) -> Self: return self._with_callable(lambda _input: FunctionExpression("abs", _input)) diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py index 156f688c7c..991d6993ed 100644 --- a/narwhals/_duckdb/group_by.py +++ b/narwhals/_duckdb/group_by.py @@ -17,16 +17,16 @@ class DuckDBGroupBy(LazyGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]): def __init__( self, df: DuckDBLazyFrame, - keys: Sequence[str], + keys: Sequence[DuckDBExpr] | Sequence[str], /, *, drop_null_keys: bool, ) -> None: - self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df - self._keys = list(keys) + frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys) + self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame: agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs))) return self.compliant._with_native( self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type] - ) + ).rename(dict(zip(self._keys, self._output_key_names))) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 192b02f915..7ee2878012 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -96,8 +96,6 @@ def evaluate_output_names_and_aliases( expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str] ) -> tuple[Sequence[str], Sequence[str]]: output_names = expr._evaluate_output_names(df) - if not output_names: - return [], [] aliases = ( output_names if expr._alias_output_names is None diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 9774349249..7030e6304f 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -569,7 +569,9 @@ def collect( raise ValueError(msg) # pragma: no cover # --- actions --- - def group_by(self, *keys: str, drop_null_keys: bool) -> PandasLikeGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[PandasLikeExpr], *, drop_null_keys: bool + ) -> PandasLikeGroupBy: from narwhals._pandas_like.group_by import PandasLikeGroupBy return PandasLikeGroupBy(self, keys, drop_null_keys=drop_null_keys) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index ef3db6ffef..eb2b3cb5f8 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -38,19 +38,22 @@ class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]): def __init__( self, df: PandasLikeDataFrame, - keys: Sequence[str], + keys: Sequence[PandasLikeExpr] | Sequence[str], /, *, drop_null_keys: bool, ) -> None: - self._compliant_frame = df - self._keys: list[str] = list(keys) + self._df = df + self._drop_null_keys = drop_null_keys + self._compliant_frame, self._keys, self._output_key_names = self._parse_keys( + df, keys=keys + ) # Drop index to avoid potential collisions: # https://github.com/narwhals-dev/narwhals/issues/1907. - if set(df.native.index.names).intersection(df.columns): - native_frame = df.native.reset_index(drop=True) + if set(self.compliant.native.index.names).intersection(self.compliant.columns): + native_frame = self.compliant.native.reset_index(drop=True) else: - native_frame = df.native + native_frame = self.compliant.native if ( self.compliant._implementation.is_pandas() and self.compliant._backend_version < (1, 1) @@ -82,10 +85,9 @@ def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915 new_names: list[str] = self._keys.copy() all_aggs_are_simple = True + exclude = (*self._keys, *self._output_key_names) for expr in exprs: - _, aliases = evaluate_output_names_and_aliases( - expr, self.compliant, self._keys - ) + _, aliases = evaluate_output_names_and_aliases(expr, self.compliant, exclude) new_names.extend(aliases) if not self._is_simple(expr): all_aggs_are_simple = False @@ -111,7 +113,7 @@ def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915 if all_aggs_are_simple: for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases( - expr, self.compliant, self._keys + expr, self.compliant, exclude ) if expr._depth == 0: # e.g. `agg(nw.len())` @@ -242,7 +244,7 @@ def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame: # noqa: PLR0915 result.reset_index(inplace=True) # noqa: PD002 return self.compliant._with_native( select_columns_by_name(result, new_names, backend_version, implementation) - ) + ).rename(dict(zip(self._keys, self._output_key_names))) if self.compliant.native.empty: # Don't even attempt this, it's way too inconsistent across pandas versions. @@ -287,12 +289,11 @@ def func(df: Any) -> Any: # Keep inplace=True to avoid making a redundant copy. # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files result_complex.reset_index(inplace=True) # noqa: PD002 - return self.compliant._with_native( select_columns_by_name( result_complex, new_names, backend_version, implementation ) - ) + ).rename(dict(zip(self._keys, self._output_key_names))) def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: with warnings.catch_warnings(): @@ -301,5 +302,9 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: message=".*a length 1 tuple will be returned", category=FutureWarning, ) + for key, group in self._grouped: - yield (key, self.compliant._with_native(group)) + yield ( + key, + self.compliant._with_native(group).simple_select(*self._df.columns), + ) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index fe9d3aa140..e9a25fecfc 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -44,6 +44,7 @@ from typing_extensions import TypeAlias from typing_extensions import TypeIs + from narwhals._polars.expr import PolarsExpr from narwhals._polars.group_by import PolarsGroupBy from narwhals._polars.group_by import PolarsLazyGroupBy from narwhals._translate import IntoArrowTable @@ -97,12 +98,11 @@ class PolarsDataFrame: write_csv: Method[Any] write_parquet: Method[None] + # CompliantDataFrame + _evaluate_aliases: Any + def __init__( - self, - df: pl.DataFrame, - *, - backend_version: tuple[int, ...], - version: Version, + self, df: pl.DataFrame, *, backend_version: tuple[int, ...], version: Version ) -> None: self._native_frame = df self._backend_version = backend_version @@ -403,7 +403,9 @@ def to_dict( else: return self.native.to_dict(as_series=False) - def group_by(self, *keys: str, drop_null_keys: bool) -> PolarsGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool + ) -> PolarsGroupBy: from narwhals._polars.group_by import PolarsGroupBy return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys) @@ -507,15 +509,13 @@ class PolarsLazyFrame: tail: Method[Self] unique: Method[Self] with_columns: Method[Self] - # NOTE: Temporary, just trying to factor out utils + + # CompliantLazyFrame _evaluate_expr: Any + _evaluate_aliases: Any def __init__( - self, - df: pl.LazyFrame, - *, - backend_version: tuple[int, ...], - version: Version, + self, df: pl.LazyFrame, *, backend_version: tuple[int, ...], version: Version ) -> None: self._native_frame = df self._backend_version = backend_version @@ -651,7 +651,9 @@ def collect( msg = f"Unsupported `backend` value: {backend}" # pragma: no cover raise ValueError(msg) # pragma: no cover - def group_by(self, *keys: str, drop_null_keys: bool) -> PolarsLazyGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool + ) -> PolarsLazyGroupBy: from narwhals._polars.group_by import PolarsLazyGroupBy return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index ba8f27b3ee..c6e41205e5 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -256,6 +256,7 @@ def struct(self) -> PolarsExprStructNamespace: # CompliantExpr _alias_output_names: Any + _evaluate_aliases: Any _evaluate_output_names: Any _is_multi_output_unnamed: Any __call__: Any diff --git a/narwhals/_polars/group_by.py b/narwhals/_polars/group_by.py index 9152b0cf3a..511f57b184 100644 --- a/narwhals/_polars/group_by.py +++ b/narwhals/_polars/group_by.py @@ -5,7 +5,7 @@ from typing import Sequence from typing import cast -from narwhals._polars.utils import extract_native +from narwhals.utils import is_sequence_of if TYPE_CHECKING: from polars.dataframe.group_by import GroupBy as NativeGroupBy @@ -18,23 +18,33 @@ class PolarsGroupBy: _compliant_frame: PolarsDataFrame - _keys: Sequence[str] + _grouped: NativeGroupBy + _drop_null_keys: bool + _output_names: Sequence[str] @property def compliant(self) -> PolarsDataFrame: return self._compliant_frame def __init__( - self, df: PolarsDataFrame, keys: Sequence[str], /, *, drop_null_keys: bool + self, + df: PolarsDataFrame, + keys: Sequence[PolarsExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - self._compliant_frame = df self._keys = list(keys) - df = df.drop_nulls(keys) if drop_null_keys else df - self._grouped: NativeGroupBy = df._native_frame.group_by(keys) + self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df + self._grouped = ( + self.compliant.native.group_by(keys) + if is_sequence_of(keys, str) + else self.compliant.native.group_by(arg.native for arg in keys) + ) def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame: - from_native = self.compliant._with_native - return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) + agg_result = self._grouped.agg(arg.native for arg in aggs) + return self.compliant._with_native(agg_result) def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: for key, df in self._grouped: @@ -43,20 +53,30 @@ def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: class PolarsLazyGroupBy: _compliant_frame: PolarsLazyFrame - _keys: Sequence[str] + _grouped: NativeLazyGroupBy + _drop_null_keys: bool + _output_names: Sequence[str] @property def compliant(self) -> PolarsLazyFrame: return self._compliant_frame def __init__( - self, df: PolarsLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool + self, + df: PolarsLazyFrame, + keys: Sequence[PolarsExpr] | Sequence[str], + /, + *, + drop_null_keys: bool, ) -> None: - self._compliant_frame = df self._keys = list(keys) - df = df.drop_nulls(keys) if drop_null_keys else df - self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys) + self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df + self._grouped = ( + self.compliant.native.group_by(keys) + if is_sequence_of(keys, str) + else self.compliant.native.group_by(arg.native for arg in keys) + ) def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame: - from_native = self.compliant._with_native - return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) + agg_result = self._grouped.agg(arg.native for arg in aggs) + return self.compliant._with_native(agg_result) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index ea8ed4d31e..b1641ff107 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -215,17 +215,18 @@ def __eq__(self, other: object) -> Self: # type: ignore[override] def __ne__(self, other: object) -> Self: # type: ignore[override] return self._with_native(self.native.__ne__(extract_native(other))) + # NOTE: `pyright` is being reasonable here def __ge__(self, other: Any) -> Self: - return self._with_native(self.native.__ge__(extract_native(other))) + return self._with_native(self.native.__ge__(extract_native(other))) # pyright: ignore[reportArgumentType] def __gt__(self, other: Any) -> Self: - return self._with_native(self.native.__gt__(extract_native(other))) + return self._with_native(self.native.__gt__(extract_native(other))) # pyright: ignore[reportArgumentType] def __le__(self, other: Any) -> Self: - return self._with_native(self.native.__le__(extract_native(other))) + return self._with_native(self.native.__le__(extract_native(other))) # pyright: ignore[reportArgumentType] def __lt__(self, other: Any) -> Self: - return self._with_native(self.native.__lt__(extract_native(other))) + return self._with_native(self.native.__lt__(extract_native(other))) # pyright: ignore[reportArgumentType] def __and__(self, other: PolarsSeries | bool | Any) -> Self: return self._with_native(self.native.__and__(extract_native(other))) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 6a2bf3e961..e9af4c01c7 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -22,46 +22,34 @@ from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: - from narwhals._polars.dataframe import PolarsDataFrame - from narwhals._polars.dataframe import PolarsLazyFrame - from narwhals._polars.expr import PolarsExpr - from narwhals._polars.series import PolarsSeries + from typing_extensions import TypeIs + from narwhals.dtypes import DType + from narwhals.utils import _StoresNative T = TypeVar("T") + NativeT = TypeVar( + "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr" + ) @overload -def extract_native(obj: PolarsDataFrame) -> pl.DataFrame: ... - - -@overload -def extract_native(obj: PolarsLazyFrame) -> pl.LazyFrame: ... - - -@overload -def extract_native(obj: PolarsSeries) -> pl.Series: ... - - -@overload -def extract_native(obj: PolarsExpr) -> pl.Expr: ... - - +def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ... @overload def extract_native(obj: T) -> T: ... +def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T: + return obj.native if _is_compliant_polars(obj) else obj -def extract_native( - obj: PolarsDataFrame | PolarsLazyFrame | PolarsSeries | PolarsExpr | T, -) -> pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr | T: +def _is_compliant_polars( + obj: _StoresNative[NativeT] | Any, +) -> TypeIs[_StoresNative[NativeT]]: from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries - if isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)): - return obj.native - return obj + return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)) def extract_args_kwargs( diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 2ba77db442..9cde2f3b05 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -295,7 +295,9 @@ def drop(self, columns: Sequence[str], *, strict: bool) -> Self: def head(self, n: int) -> Self: return self._with_native(self.native.limit(num=n)) - def group_by(self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: + def group_by( + self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool + ) -> SparkLikeLazyGroupBy: from narwhals._spark_like.group_by import SparkLikeLazyGroupBy return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys) diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 22a6ef4a38..b5cb5003f5 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -16,19 +16,21 @@ class SparkLikeLazyGroupBy(LazyGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "C def __init__( self, df: SparkLikeLazyFrame, - keys: Sequence[str], + keys: Sequence[SparkLikeExpr] | Sequence[str], /, *, drop_null_keys: bool, ) -> None: - self._compliant_frame = df.drop_nulls(subset=None) if drop_null_keys else df - self._keys = list(keys) + frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys) + self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: - if agg_columns := list(self._evaluate_exprs(exprs)): - return self.compliant._with_native( - self.compliant.native.groupBy(*self._keys).agg(*agg_columns) - ) - return self.compliant._with_native( - self.compliant.native.select(*self._keys).dropDuplicates() + result = ( + self.compliant.native.groupBy(*self._keys).agg(*agg_columns) + if (agg_columns := list(self._evaluate_exprs(exprs))) + else self.compliant.native.select(*self._keys).dropDuplicates() + ) + + return self.compliant._with_native(result).rename( + dict(zip(self._keys, self._output_key_names)) ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 2ab5c354c4..f5a9f5493e 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -57,7 +57,7 @@ from narwhals._compliant import CompliantDataFrame from narwhals._compliant import CompliantLazyFrame - from narwhals._compliant import IntoCompliantExpr + from narwhals._compliant.typing import CompliantExprAny from narwhals._compliant.typing import EagerNamespaceAny from narwhals.group_by import GroupBy from narwhals.group_by import LazyGroupBy @@ -104,7 +104,7 @@ def _with_compliant(self, df: Any) -> Self: def _flatten_and_extract( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr - ) -> tuple[list[IntoCompliantExpr[Any, Any]], list[ExprKind]]: + ) -> tuple[list[CompliantExprAny], list[ExprKind]]: """Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names.""" out_exprs = [] out_kinds = [] @@ -1510,13 +1510,24 @@ def filter( """ return super().filter(*predicates, **constraints) + @overload + def group_by( + self, *keys: IntoExpr | Iterable[IntoExpr], drop_null_keys: Literal[False] = ... + ) -> GroupBy[Self]: ... + + @overload + def group_by( + self, *keys: str | Iterable[str], drop_null_keys: Literal[True] + ) -> GroupBy[Self]: ... + def group_by( - self, *keys: str | Iterable[str], drop_null_keys: bool = False + self, *keys: IntoExpr | Iterable[IntoExpr], drop_null_keys: bool = False ) -> GroupBy[Self]: r"""Start a group by operation. Arguments: - *keys: Column(s) to group by. Accepts multiple columns names as a list. + *keys: Column(s) to group by. Accepts expression input. Strings are parsed as + column names. drop_null_keys: if True, then groups where any key is null won't be included in the result. @@ -1558,19 +1569,47 @@ def group_by( 1 b 2 4 2 b 3 2 3 c 3 1 + + Expressions are also accepted. + + >>> nw.from_native(df_native, eager_only=True).group_by( + ... "a", nw.col("b") // 2 + ... ).agg(nw.col("c").mean()).to_native() + a b c + 0 a 0 4.0 + 1 b 1 3.0 + 2 c 1 1.0 """ - from narwhals.expr import Expr from narwhals.group_by import GroupBy - from narwhals.series import Series flat_keys = flatten(keys) - if any(isinstance(x, (Expr, Series)) for x in flat_keys): - msg = ( - "`group_by` with expression or Series keys is not (yet?) supported.\n\n" - "Hint: instead of `df.group_by(nw.col('a'))`, use `df.group_by('a')`." - ) + + if all(isinstance(key, str) for key in flat_keys): + return GroupBy(self, flat_keys, drop_null_keys=drop_null_keys) + + from narwhals import col + from narwhals.expr import Expr + from narwhals.series import Series + + key_is_expr_or_series = tuple(isinstance(k, (Expr, Series)) for k in flat_keys) + + if drop_null_keys and any(key_is_expr_or_series): + msg = "drop_null_keys cannot be True when keys contains Expr or Series" raise NotImplementedError(msg) - return GroupBy(self, *flat_keys, drop_null_keys=drop_null_keys) + + _keys = [ + k if is_expr else col(k) + for k, is_expr in zip(flat_keys, key_is_expr_or_series) + ] + expr_flat_keys, kinds = self._flatten_and_extract(*_keys) + + if not all(kind is ExprKind.TRANSFORM for kind in kinds): + from narwhals.exceptions import ComputeError + + msg = "Group by is not supported with keys that are not transformation expressions" + raise ComputeError(msg) + + return GroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( self, @@ -2792,15 +2831,24 @@ def filter( return super().filter(*predicates, **constraints) + @overload + def group_by( + self, *keys: IntoExpr | Iterable[IntoExpr], drop_null_keys: Literal[False] = ... + ) -> LazyGroupBy[Self]: ... + + @overload + def group_by( + self, *keys: str | Iterable[str], drop_null_keys: Literal[True] + ) -> LazyGroupBy[Self]: ... + def group_by( - self, *keys: str | Iterable[str], drop_null_keys: bool = False + self, *keys: IntoExpr | Iterable[IntoExpr], drop_null_keys: bool = False ) -> LazyGroupBy[Self]: r"""Start a group by operation. Arguments: - *keys: - Column(s) to group by. Accepts expression input. Strings are - parsed as column names. + *keys: Column(s) to group by. Accepts expression input. Strings are parsed as + column names. drop_null_keys: if True, then groups where any key is null won't be included in the result. @@ -2823,19 +2871,46 @@ def group_by( │ b │ 2 │ └─────────┴────────┘ + + Expressions are also accepted. + + >>> df.group_by(nw.col("b").str.len_chars()).agg( + ... nw.col("a").sum() + ... ).to_native() + ┌───────┬────────┐ + │ b │ a │ + │ int64 │ int128 │ + ├───────┼────────┤ + │ 1 │ 6 │ + └───────┴────────┘ + """ - from narwhals.expr import Expr from narwhals.group_by import LazyGroupBy - from narwhals.series import Series flat_keys = flatten(keys) - if any(isinstance(x, (Expr, Series)) for x in flat_keys): - msg = ( - "`group_by` with expression or Series keys is not (yet?) supported.\n\n" - "Hint: instead of `df.group_by(nw.col('a'))`, use `df.group_by('a')`." - ) + + if all(isinstance(key, str) for key in flat_keys): + return LazyGroupBy(self, flat_keys, drop_null_keys=drop_null_keys) + + from narwhals import col + from narwhals.expr import Expr + + key_is_expr = tuple(isinstance(k, Expr) for k in flat_keys) + + if drop_null_keys and any(key_is_expr): + msg = "drop_null_keys cannot be True when keys contains Expr" raise NotImplementedError(msg) - return LazyGroupBy(self, *flat_keys, drop_null_keys=drop_null_keys) + + _keys = [k if is_expr else col(k) for k, is_expr in zip(flat_keys, key_is_expr)] + expr_flat_keys, kinds = self._flatten_and_extract(*_keys) + + if not all(kind is ExprKind.TRANSFORM for kind in kinds): + from narwhals.exceptions import ComputeError + + msg = "Group by is not supported with keys that are not transformation expressions" + raise ComputeError(msg) + + return LazyGroupBy(self, expr_flat_keys, drop_null_keys=drop_null_keys) def sort( self, diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 51677b1b7f..02ad83c896 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -5,6 +5,7 @@ from typing import Generic from typing import Iterable from typing import Iterator +from typing import Sequence from typing import TypeVar from narwhals._expression_parsing import all_exprs_are_scalar_like @@ -14,6 +15,7 @@ from narwhals.utils import tupleify if TYPE_CHECKING: + from narwhals._compliant.typing import CompliantExprAny from narwhals.dataframe import LazyFrame from narwhals.expr import Expr @@ -21,11 +23,18 @@ class GroupBy(Generic[DataFrameT]): - def __init__(self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> None: + def __init__( + self, + df: DataFrameT, + keys: Sequence[str] | Sequence[CompliantExprAny], + /, + *, + drop_null_keys: bool, + ) -> None: self._df: DataFrameT = df self._keys = keys self._grouped = self._df._compliant_frame.group_by( - *self._keys, drop_null_keys=drop_null_keys + self._keys, drop_null_keys=drop_null_keys ) def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: @@ -98,11 +107,18 @@ def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: class LazyGroupBy(Generic[LazyFrameT]): - def __init__(self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> None: + def __init__( + self, + df: LazyFrameT, + keys: Sequence[str] | Sequence[CompliantExprAny], + /, + *, + drop_null_keys: bool, + ) -> None: self._df: LazyFrameT = df self._keys = keys self._grouped = self._df._compliant_frame.group_by( - *self._keys, drop_null_keys=drop_null_keys + self._keys, drop_null_keys=drop_null_keys ) def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: diff --git a/narwhals/utils.py b/narwhals/utils.py index 01e3574891..2b7de13e7f 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1302,6 +1302,15 @@ def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]: return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp)) +def is_sequence_of(obj: Any, tp: type[_T]) -> TypeIs[Sequence[_T]]: + # Check if an object is a sequence of `tp`, only sniffing the first element. + return bool( + is_sequence_but_not_str(obj) + and (first := next(iter(obj), None)) + and isinstance(first, tp) + ) + + def find_stacklevel() -> int: """Find the first place in the stack that is not inside narwhals. diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index 31cc607374..8908c536ca 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -26,7 +26,6 @@ {"a": [1, 2, 3], "min": [4, 4, 4]}, ), ], - ids=range(5), ) def test_scalar_reduction_select( constructor: Constructor, @@ -56,7 +55,6 @@ def test_scalar_reduction_select( {"a": [1, 2, 3], "min": [4, 4, 4]}, ), ], - ids=range(5), ) def test_scalar_reduction_with_columns( constructor: Constructor, diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index bed99c2213..3335ce655e 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -12,7 +12,6 @@ from narwhals.exceptions import NarwhalsError from tests.utils import DASK_VERSION from tests.utils import DUCKDB_VERSION -from tests.utils import PANDAS_VERSION from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager @@ -54,17 +53,28 @@ def test_invalid_select(constructor: Constructor, invalid_select: Any) -> None: nw.from_native(constructor({"a": [1, 2, 3]})).select(invalid_select) -def test_select_boolean_cols(request: pytest.FixtureRequest) -> None: - if PANDAS_VERSION < (1, 1): - # bug in old pandas - request.applymarker(pytest.mark.xfail) +def test_select_boolean_cols() -> None: df = nw.from_native(pd.DataFrame({True: [1, 2], False: [3, 4]}), eager_only=True) - result = df.group_by(True).agg(nw.col(False).max()) # type: ignore[arg-type]# noqa: FBT003 + result = df.group_by(True).agg(nw.col(False).max()) # type: ignore[arg-type, call-overload] # noqa: FBT003 assert_equal_data(result.to_dict(as_series=False), {True: [1, 2]}) # type: ignore[dict-item] result = df.select(nw.col([False, True])) # type: ignore[list-item] assert_equal_data(result.to_dict(as_series=False), {True: [1, 2], False: [3, 4]}) # type: ignore[dict-item] +def test_select_boolean_cols_multi_group_by() -> None: + df = nw.from_native( + pd.DataFrame({True: [1, 2], False: [3, 4], 2: [1, 1]}), eager_only=True + ) + result = df.group_by(True, 2).agg(nw.col(False).max()) # type: ignore[arg-type, call-overload] # noqa: FBT003 + assert_equal_data( + result.to_dict(as_series=False), + {True: [1, 2], 2: [1, 1], False: [3, 4]}, # type: ignore[dict-item] + ) + + result = df.select(nw.col([False, True])) # type: ignore[list-item] + assert_equal_data(result.to_dict(as_series=False), {True: [1, 2], False: [3, 4]}) # type: ignore[dict-item] + + def test_comparison_with_list_error_message() -> None: msg = "Expected Series or scalar, got list." with pytest.raises(TypeError, match=msg): diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 84d1a90304..0544e865b0 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from contextlib import nullcontext from typing import Any from typing import Mapping @@ -9,7 +10,10 @@ import pytest import narwhals as nw +from narwhals.exceptions import ComputeError from narwhals.exceptions import InvalidOperationError +from narwhals.exceptions import LengthChangingExprError +from narwhals.exceptions import OrderDependentExprError from tests.utils import PANDAS_VERSION from tests.utils import PYARROW_VERSION from tests.utils import Constructor @@ -20,6 +24,8 @@ df_pandas = pd.DataFrame(data) +POLARS_COLLECT_STREAMING_ENGINE = os.environ.get("NARWHALS_POLARS_NEW_STREAMING", None) + def test_group_by_complex() -> None: expected = {"a": [1, 3], "b": [-3.5, -3.0]} @@ -294,7 +300,11 @@ def test_key_with_nulls_iter( ) -> None: if PANDAS_VERSION < (1, 0) and "pandas_constructor" in str(constructor_eager): pytest.skip("Grouping by null values is not supported in pandas < 1.0.0") - data = {"b": ["4", "5", None, "7"], "a": [1, 2, 3, 4], "c": ["4", "3", None, None]} + data = { + "b": [None, "4", "5", None, "7"], + "a": [None, 1, 2, 3, 4], + "c": [None, "4", "3", None, None], + } result = dict( nw.from_native(constructor_eager(data), eager_only=True) .group_by("b", "c", drop_null_keys=True) @@ -415,12 +425,6 @@ def test_all_kind_of_aggs( assert_equal_data(result, expected) -def test_group_by_expr(constructor: Constructor) -> None: - df = nw.from_native(constructor({"a": [1, 1, 3], "b": [4, 5, 6]})) - with pytest.raises(NotImplementedError, match=r"not \(yet\?\) supported"): - df.group_by(nw.col("a")).agg(nw.col("b").mean()) # type: ignore[arg-type] - - def test_pandas_group_by_index_and_column_overlap() -> None: df = pd.DataFrame( {"a": [1, 1, 2], "b": [4, 5, 6]}, index=pd.Index([0, 1, 2], name="a") @@ -431,7 +435,7 @@ def test_pandas_group_by_index_and_column_overlap() -> None: key, result = next(iter(nw.from_native(df, eager_only=True).group_by("a"))) assert key == (1,) - expected_native = pd.DataFrame({"a": [1, 1], "b": [4, 5]}) + expected_native = pd.DataFrame({"a": [1, 1], "b": [4, 5]}, index=pd.Index([0, 1])) pd.testing.assert_frame_equal(result.to_native(), expected_native) @@ -455,3 +459,139 @@ def test_fancy_functions(constructor: Constructor) -> None: .sort("a") ) assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "aggs", "expected", "sort_by"), + [ + ( + [nw.col("a").abs(), nw.col("a").abs().alias("a_with_alias")], + [nw.col("x").sum()], + {"a": [1, 2], "a_with_alias": [1, 2], "x": [5, 5]}, + ["a"], + ), + ( + [nw.col("a").alias("x")], + [nw.col("x").mean().alias("y")], + {"x": [-1, 1, 2], "y": [4.0, 0.5, 2.5]}, + ["x"], + ), + ( + [nw.col("a")], + [nw.col("a").count().alias("foo-bar"), nw.all().sum()], + {"a": [-1, 1, 2], "foo-bar": [1, 2, 2], "x": [4, 1, 5], "y": [1.5, 0, 0]}, + ["a"], + ), + ( + [nw.col("a", "y").abs()], + [nw.col("x").sum()], + {"a": [1, 1, 2], "y": [0.5, 1.5, 1], "x": [1, 4, 5]}, + ["a", "y"], + ), + ( + [nw.col("a").abs().alias("y")], + [nw.all().sum().name.suffix("c")], + {"y": [1, 2], "ac": [1, 4], "xc": [5, 5]}, + ["y"], + ), + ( + [nw.selectors.by_dtype(nw.Float64()).abs()], + [nw.selectors.numeric().sum()], + {"y": [0.5, 1.0, 1.5], "a": [2, 4, -1], "x": [1, 5, 4]}, + ["y"], + ), + ], + ids=range(6), +) +def test_group_by_expr( + request: pytest.FixtureRequest, + constructor: Constructor, + keys: list[nw.Expr], + aggs: list[nw.Expr], + expected: dict[str, list[Any]], + sort_by: list[str], +) -> None: + request_id = request.node.callspec.id + if ( + POLARS_COLLECT_STREAMING_ENGINE + and request_id.startswith("polars[lazy]") + and request_id.endswith("0") + ): + # Blocked by upstream issue as of polars==1.27.1 + # See: https://github.com/pola-rs/polars/issues/22238 + request.applymarker(pytest.mark.xfail) + + data = { + "a": [1, 1, 2, 2, -1], + "x": [0, 1, 2, 3, 4], + "y": [0.5, -0.5, 1.0, -1.0, 1.5], + } + df = nw.from_native(constructor(data)) + result = df.group_by(*keys).agg(*aggs).sort(*sort_by) + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("keys", "lazy_context"), + [ + ( + [nw.col("a").drop_nulls()], + pytest.raises(LengthChangingExprError), + ), # Filtration + ( + [nw.col("a").alias("foo"), nw.col("a").drop_nulls()], + pytest.raises(LengthChangingExprError), + ), # Transform and Filtration + ( + [nw.col("a").alias("foo"), nw.col("a").max()], + pytest.raises(ComputeError), + ), # Transform and Aggregation + ( + [nw.col("a").alias("foo"), nw.col("a").cum_max()], + pytest.raises(OrderDependentExprError), + ), # Transform and Window + ([nw.lit(42)], pytest.raises(ComputeError)), # Literal + ], +) +def test_group_by_raise_if_not_transform( + constructor: Constructor, keys: list[nw.Expr], lazy_context: Any +) -> None: + data = {"a": [1, 2, 2, None], "b": [0, 1, 2, 3], "x": [1, 2, 3, 4]} + df = nw.from_native(constructor(data)) + + context: Any = ( + lazy_context if isinstance(df, nw.LazyFrame) else pytest.raises(ComputeError) + ) + with context: + df.group_by(keys).agg(nw.col("x").max()) + + +@pytest.mark.parametrize( + "keys", [[nw.col("a").abs()], ["a", nw.col("a").abs().alias("a_test")]] +) +def test_group_by_raise_drop_null_keys_with_exprs( + constructor: Constructor, keys: list[nw.Expr | str] +) -> None: + data = { + "a": [1, 1, 2, 2, -1], + "x": [0, 1, 2, 3, 4], + "y": [0.5, -0.5, 1.0, -1.0, 1.5], + } + df = nw.from_native(constructor(data)) + with pytest.raises( + NotImplementedError, + match="drop_null_keys cannot be True when keys contains Expr", + ): + df.group_by(*keys, drop_null_keys=True) # type: ignore[call-overload] + + +def test_group_by_selector(constructor: Constructor) -> None: + data = {"a": [1, 1, 1], "b": [4, 4, 6], "c": [7.5, 8.5, 9.0]} + result = ( + nw.from_native(constructor(data)) + .group_by(nw.selectors.by_dtype(nw.Int64)) + .agg(nw.col("c").mean()) + .sort("a", "b") + ) + expected = {"a": [1, 1], "b": [4, 6], "c": [8.0, 9.0]} + assert_equal_data(result, expected)