Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from typing_extensions import NotRequired, Self, TypedDict

from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.expr import CompliantExpr, EagerExpr
from narwhals._compliant.namespace import EagerNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
Expand Down Expand Up @@ -96,6 +97,8 @@ def to_narwhals(self) -> Series[NativeSeriesT]:
def _with_native(self, series: Any) -> Self: ...
def _with_version(self, version: Version) -> Self: ...

def _to_expr(self) -> CompliantExpr[Any, Self]: ...

# NOTE: `polars`
@property
def dtype(self) -> DType: ...
Expand Down Expand Up @@ -243,6 +246,9 @@ def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...

def _to_expr(self) -> EagerExpr[Any, Any]:
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]

def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
Expand Down
18 changes: 1 addition & 17 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from __future__ import annotations

from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

from narwhals._utils import is_compliant_expr, zip_strict
from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d
from narwhals.dependencies import is_narwhals_series, is_numpy_array
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError

if TYPE_CHECKING:
Expand Down Expand Up @@ -46,13 +45,6 @@ def is_series(obj: Any) -> TypeIs[Series[Any]]:
return isinstance(obj, Series)


def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]:
from narwhals.expr import Expr
from narwhals.series import Series

return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj)


def combine_evaluate_output_names(
*exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
Expand Down Expand Up @@ -584,14 +576,6 @@ def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> No
raise InvalidOperationError(msg)


def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
# Raise if any argument in `args` isn't an aggregation or literal.
# For Series input, we don't raise (yet), we let such checks happen later,
# as this function works lazily and so can't evaluate lengths.
exprs = chain(args, kwargs.values())
return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)


def apply_n_ary_operation(
plx: CompliantNamespaceAny,
n_ary_function: Callable[..., CompliantExprAny],
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import polars as pl

from narwhals._polars.expr import PolarsExpr
from narwhals._polars.utils import (
BACKEND_VERSION,
SERIES_ACCEPTS_PD_INDEX,
Expand Down Expand Up @@ -150,6 +151,10 @@ def __init__(self, series: pl.Series, *, version: Version) -> None:
self._native_series = series
self._version = version

def _to_expr(self) -> PolarsExpr:
# Polars can treat Series as Expr, so just pass down `self.native`.
return PolarsExpr(self.native, version=self._version) # type: ignore[arg-type]

@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
Expand Down
106 changes: 56 additions & 50 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from narwhals._expression_parsing import (
ExprKind,
check_expressions_preserve_length,
is_into_expr_eager,
is_expr,
is_scalar_like,
is_series,
)
from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl
from narwhals._utils import (
Expand All @@ -43,14 +44,14 @@
supports_arrow_c_stream,
zip_strict,
)
from narwhals.dependencies import is_numpy_array_2d, is_pyarrow_table
from narwhals.dependencies import is_numpy_array_1d, is_numpy_array_2d, is_pyarrow_table
from narwhals.exceptions import (
ColumnNotFoundError,
InvalidIntoExprError,
InvalidOperationError,
PerformanceWarning,
)
from narwhals.functions import _from_dict_no_backend, _is_into_schema
from narwhals.functions import _from_dict_no_backend, _is_into_schema, col, new_series
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.translate import to_native
Expand All @@ -67,9 +68,10 @@
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias

from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame
from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny
from narwhals._compliant.typing import CompliantExprAny
from narwhals._translate import IntoArrowTable
from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars
from narwhals.expr import Expr
from narwhals.group_by import GroupBy, LazyGroupBy
from narwhals.typing import (
AsofJoinStrategy,
Expand All @@ -87,6 +89,7 @@
SingleIndexSelector,
SizeUnit,
UniqueKeepStrategy,
_1DArray,
_2DArray,
)

Expand Down Expand Up @@ -148,18 +151,21 @@ def _flatten_and_extract(
# NOTE: Strings are interpreted as column names.
out_exprs = []
out_kinds = []
for expr in flatten(exprs):
compliant_expr = self._extract_compliant(expr)
out_exprs.append(compliant_expr)
out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False))
for alias, expr in named_exprs.items():
compliant_expr = self._extract_compliant(expr).alias(alias)
out_exprs.append(compliant_expr)
out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False))
ns = self.__narwhals_namespace__()
all_exprs = chain(
(self._parse_into_expr(x) for x in flatten(exprs)),
(
self._parse_into_expr(expr).alias(alias)
for alias, expr in named_exprs.items()
),
)
for expr in all_exprs:
out_exprs.append(expr._to_compliant_expr(ns))
out_kinds.append(ExprKind.from_expr(expr))
return out_exprs, out_kinds

@abstractmethod
def _extract_compliant(self, arg: Any) -> Any:
def _parse_into_expr(self, arg: Any) -> Expr:
raise NotImplementedError

def _extract_compliant_frame(self, other: Self | Any, /) -> Any:
Expand Down Expand Up @@ -476,10 +482,15 @@ class DataFrame(BaseFrame[DataFrameT]):
def _compliant(self) -> CompliantDataFrame[Any, Any, DataFrameT, Self]:
return self._compliant_frame

def _extract_compliant(self, arg: Any) -> Any:
if is_into_expr_eager(arg):
plx: EagerNamespaceAny = self.__narwhals_namespace__()
return plx.parse_into_expr(arg, str_as_lit=False)
def _parse_into_expr(self, arg: Expr | Series[Any] | _1DArray | str) -> Expr:
if isinstance(arg, str):
return col(arg)
if is_numpy_array_1d(arg):
return new_series("", arg, backend=self.implementation)._to_expr()
if is_series(arg):
return arg._to_expr()
if is_expr(arg):
return arg
raise InvalidIntoExprError.from_invalid_type(type(arg))

@property
Expand Down Expand Up @@ -2287,39 +2298,34 @@ class LazyFrame(BaseFrame[LazyFrameT]):
def _compliant(self) -> CompliantLazyFrame[Any, LazyFrameT, Self]:
return self._compliant_frame

def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, Series): # pragma: no cover
msg = "Binary operations between Series and LazyFrame are not supported."
raise TypeError(msg)
if isinstance(arg, (Expr, str)):
if isinstance(arg, Expr):
if arg._metadata.n_orderable_ops:
msg = (
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
"Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n"
"For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n"
"`'date'` which orders your data, then replace:\n\n"
" nw.col('price').cum_sum()\n\n"
" with:\n\n"
" nw.col('price').cum_sum().over(order_by='date')\n"
" ^^^^^^^^^^^^^^^^^^^^^^\n\n"
"See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/."
)
raise InvalidOperationError(msg)
if arg._metadata.is_filtration:
msg = (
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
"followed by an aggregation.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
)
raise InvalidOperationError(msg)
return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False)
def _parse_into_expr(self, arg: Expr | str) -> Expr:
if isinstance(arg, str):
return col(arg)
if is_expr(arg):
if arg._metadata.n_orderable_ops:
msg = (
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
"Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n"
"For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n"
"`'date'` which orders your data, then replace:\n\n"
" nw.col('price').cum_sum()\n\n"
" with:\n\n"
" nw.col('price').cum_sum().over(order_by='date')\n"
" ^^^^^^^^^^^^^^^^^^^^^^\n\n"
"See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/."
)
raise InvalidOperationError(msg)
if arg._metadata.is_filtration:
msg = (
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
"followed by an aggregation.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
)
raise InvalidOperationError(msg)
return arg
raise InvalidIntoExprError.from_invalid_type(type(arg))

@property
Expand Down
27 changes: 5 additions & 22 deletions narwhals/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from typing import TYPE_CHECKING, Any, Generic, TypeVar

from narwhals._expression_parsing import all_exprs_are_scalar_like
from narwhals._utils import flatten, tupleify
from narwhals._utils import tupleify
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import DataFrameT

Expand Down Expand Up @@ -72,23 +71,15 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
2 b 3 2
3 c 3 1
"""
flat_aggs = tuple(flatten(aggs))
if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs)
if not all(x.is_scalar_like for x in kinds):
msg = (
"Found expression which does not aggregate.\n\n"
"All expressions passed to GroupBy.agg must aggregate.\n"
"For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
"but `df.group_by('a').agg(nw.col('b'))` is not."
)
raise InvalidOperationError(msg)
plx = self._df.__narwhals_namespace__()
compliant_aggs = (
*(x._to_compliant_expr(plx) for x in flat_aggs),
*(
value.alias(key)._to_compliant_expr(plx)
for key, value in named_aggs.items()
),
)
return self._df._with_compliant(self._grouped.agg(*compliant_aggs))

def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
Expand Down Expand Up @@ -166,21 +157,13 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
|└─────┴─────┴─────┘|
└───────────────────┘
"""
flat_aggs = tuple(flatten(aggs))
if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs)
if not all(x.is_scalar_like for x in kinds):
msg = (
"Found expression which does not aggregate.\n\n"
"All expressions passed to GroupBy.agg must aggregate.\n"
"For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
"but `df.group_by('a').agg(nw.col('b'))` is not."
)
raise InvalidOperationError(msg)
plx = self._df.__narwhals_namespace__()
compliant_aggs = (
*(x._to_compliant_expr(plx) for x in flat_aggs),
*(
value.alias(key)._to_compliant_expr(plx)
for key, value in named_aggs.items()
),
)
return self._df._with_compliant(self._grouped.agg(*compliant_aggs))
6 changes: 6 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, overload

from narwhals._expression_parsing import ExprMetadata
from narwhals._utils import (
Implementation,
Version,
Expand All @@ -20,6 +21,7 @@
from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar
from narwhals.dtypes import _validate_dtype, _validate_into_dtype
from narwhals.exceptions import ComputeError, InvalidOperationError
from narwhals.expr import Expr
from narwhals.series_cat import SeriesCatNamespace
from narwhals.series_dt import SeriesDateTimeNamespace
from narwhals.series_list import SeriesListNamespace
Expand Down Expand Up @@ -89,6 +91,10 @@ def _dataframe(self) -> type[DataFrame[Any]]:

return DataFrame

def _to_expr(self) -> Expr:
md = ExprMetadata.selector_single()
return Expr(lambda _plx: self._compliant._to_expr(), md)

def __init__(
self, series: Any, *, level: Literal["full", "lazy", "interchange"]
) -> None:
Expand Down
15 changes: 6 additions & 9 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import narwhals as nw
from narwhals import exceptions, functions as nw_f
from narwhals._exceptions import issue_warning
from narwhals._expression_parsing import is_expr
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
Implementation,
Expand Down Expand Up @@ -233,17 +234,13 @@ def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) ->
def _dataframe(self) -> type[DataFrame[Any]]:
return DataFrame

def _extract_compliant(self, arg: Any) -> Any:
def _parse_into_expr(self, arg: Expr | str) -> Expr: # type: ignore[override]
# After v1, we raise when passing order-dependent, length-changing,
# or filtration expressions to LazyFrame
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, Series): # pragma: no cover
msg = "Mixing Series with LazyFrame is not supported."
raise TypeError(msg)
if isinstance(arg, (Expr, str)):
return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False)
if isinstance(arg, str):
return col(arg)
if is_expr(arg):
return arg
raise InvalidIntoExprError.from_invalid_type(type(arg))

def collect(
Expand Down
2 changes: 1 addition & 1 deletion tests/frame/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def test_group_by_shift_raises(constructor: Constructor) -> None:
df_native = {"a": [1, 2, 3], "b": [1, 1, 2]}
df = nw.from_native(constructor(df_native))
with pytest.raises(InvalidOperationError, match="does not aggregate"):
df.group_by("b").agg(nw.col("a").shift(1))
df.group_by("b").agg(nw.col("a").abs())


def test_double_same_aggregation(
Expand Down
Loading