diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 561ff3de6d..2b0b02a3f1 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -37,6 +37,7 @@ from typing_extensions import NotRequired, Self, TypedDict from narwhals._compliant.dataframe import CompliantDataFrame + from narwhals._compliant.expr import CompliantExpr, EagerExpr from narwhals._compliant.namespace import EagerNamespace from narwhals._utils import Implementation, Version, _LimitedContext from narwhals.dtypes import DType @@ -96,6 +97,8 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def _with_native(self, series: Any) -> Self: ... def _with_version(self, version: Version) -> Self: ... + def _to_expr(self) -> CompliantExpr[Any, Self]: ... + # NOTE: `polars` @property def dtype(self) -> DType: ... @@ -243,6 +246,9 @@ def __narwhals_namespace__( self, ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ... + def _to_expr(self) -> EagerExpr[Any, Any]: + return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] + def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... def __getitem__(self, item: MultiIndexSelector[Self]) -> Self: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index b5cbd85f6b..bd5b9138c2 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -5,11 +5,10 @@ from __future__ import annotations from enum import Enum, auto -from itertools import chain from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar from narwhals._utils import is_compliant_expr, zip_strict -from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d +from narwhals.dependencies import is_narwhals_series, is_numpy_array from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError if TYPE_CHECKING: @@ -46,13 +45,6 @@ def is_series(obj: Any) -> TypeIs[Series[Any]]: return isinstance(obj, Series) -def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]: - from narwhals.expr import Expr - from narwhals.series import Series - - return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj) - - def combine_evaluate_output_names( *exprs: CompliantExpr[CompliantFrameT, Any], ) -> EvalNames[CompliantFrameT]: @@ -584,14 +576,6 @@ def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> No raise InvalidOperationError(msg) -def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool: - # Raise if any argument in `args` isn't an aggregation or literal. - # For Series input, we don't raise (yet), we let such checks happen later, - # as this function works lazily and so can't evaluate lengths. - exprs = chain(args, kwargs.values()) - return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs) - - def apply_n_ary_operation( plx: CompliantNamespaceAny, n_ary_function: Callable[..., CompliantExprAny], diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 8da184e335..7e078951dd 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -4,6 +4,7 @@ import polars as pl +from narwhals._polars.expr import PolarsExpr from narwhals._polars.utils import ( BACKEND_VERSION, SERIES_ACCEPTS_PD_INDEX, @@ -150,6 +151,10 @@ def __init__(self, series: pl.Series, *, version: Version) -> None: self._native_series = series self._version = version + def _to_expr(self) -> PolarsExpr: + # Polars can treat Series as Expr, so just pass down `self.native`. + return PolarsExpr(self.native, version=self._version) # type: ignore[arg-type] + @property def _backend_version(self) -> tuple[int, ...]: return self._implementation._backend_version() diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 5b2752e029..2deb3d240c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -19,8 +19,9 @@ from narwhals._expression_parsing import ( ExprKind, check_expressions_preserve_length, - is_into_expr_eager, + is_expr, is_scalar_like, + is_series, ) from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl from narwhals._utils import ( @@ -43,14 +44,14 @@ supports_arrow_c_stream, zip_strict, ) -from narwhals.dependencies import is_numpy_array_2d, is_pyarrow_table +from narwhals.dependencies import is_numpy_array_1d, is_numpy_array_2d, is_pyarrow_table from narwhals.exceptions import ( ColumnNotFoundError, InvalidIntoExprError, InvalidOperationError, PerformanceWarning, ) -from narwhals.functions import _from_dict_no_backend, _is_into_schema +from narwhals.functions import _from_dict_no_backend, _is_into_schema, col, new_series from narwhals.schema import Schema from narwhals.series import Series from narwhals.translate import to_native @@ -67,9 +68,10 @@ from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame - from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny + from narwhals._compliant.typing import CompliantExprAny from narwhals._translate import IntoArrowTable from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars + from narwhals.expr import Expr from narwhals.group_by import GroupBy, LazyGroupBy from narwhals.typing import ( AsofJoinStrategy, @@ -87,6 +89,7 @@ SingleIndexSelector, SizeUnit, UniqueKeepStrategy, + _1DArray, _2DArray, ) @@ -148,18 +151,21 @@ def _flatten_and_extract( # NOTE: Strings are interpreted as column names. out_exprs = [] out_kinds = [] - for expr in flatten(exprs): - compliant_expr = self._extract_compliant(expr) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) - for alias, expr in named_exprs.items(): - compliant_expr = self._extract_compliant(expr).alias(alias) - out_exprs.append(compliant_expr) - out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False)) + ns = self.__narwhals_namespace__() + all_exprs = chain( + (self._parse_into_expr(x) for x in flatten(exprs)), + ( + self._parse_into_expr(expr).alias(alias) + for alias, expr in named_exprs.items() + ), + ) + for expr in all_exprs: + out_exprs.append(expr._to_compliant_expr(ns)) + out_kinds.append(ExprKind.from_expr(expr)) return out_exprs, out_kinds @abstractmethod - def _extract_compliant(self, arg: Any) -> Any: + def _parse_into_expr(self, arg: Any) -> Expr: raise NotImplementedError def _extract_compliant_frame(self, other: Self | Any, /) -> Any: @@ -476,10 +482,15 @@ class DataFrame(BaseFrame[DataFrameT]): def _compliant(self) -> CompliantDataFrame[Any, Any, DataFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - if is_into_expr_eager(arg): - plx: EagerNamespaceAny = self.__narwhals_namespace__() - return plx.parse_into_expr(arg, str_as_lit=False) + def _parse_into_expr(self, arg: Expr | Series[Any] | _1DArray | str) -> Expr: + if isinstance(arg, str): + return col(arg) + if is_numpy_array_1d(arg): + return new_series("", arg, backend=self.implementation)._to_expr() + if is_series(arg): + return arg._to_expr() + if is_expr(arg): + return arg raise InvalidIntoExprError.from_invalid_type(type(arg)) @property @@ -2287,39 +2298,34 @@ class LazyFrame(BaseFrame[LazyFrameT]): def _compliant(self) -> CompliantLazyFrame[Any, LazyFrameT, Self]: return self._compliant_frame - def _extract_compliant(self, arg: Any) -> Any: - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Binary operations between Series and LazyFrame are not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - if isinstance(arg, Expr): - if arg._metadata.n_orderable_ops: - msg = ( - "Order-dependent expressions are not supported for use in LazyFrame.\n\n" - "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" - "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" - "`'date'` which orders your data, then replace:\n\n" - " nw.col('price').cum_sum()\n\n" - " with:\n\n" - " nw.col('price').cum_sum().over(order_by='date')\n" - " ^^^^^^^^^^^^^^^^^^^^^^\n\n" - "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." - ) - raise InvalidOperationError(msg) - if arg._metadata.is_filtration: - msg = ( - "Length-changing expressions are not supported for use in LazyFrame, unless\n" - "followed by an aggregation.\n\n" - "Hints:\n" - "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" - "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" - " use `lf.select(nw.col('a').drop_nulls().sum())\n" - ) - raise InvalidOperationError(msg) - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) + def _parse_into_expr(self, arg: Expr | str) -> Expr: + if isinstance(arg, str): + return col(arg) + if is_expr(arg): + if arg._metadata.n_orderable_ops: + msg = ( + "Order-dependent expressions are not supported for use in LazyFrame.\n\n" + "Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n" + "For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n" + "`'date'` which orders your data, then replace:\n\n" + " nw.col('price').cum_sum()\n\n" + " with:\n\n" + " nw.col('price').cum_sum().over(order_by='date')\n" + " ^^^^^^^^^^^^^^^^^^^^^^\n\n" + "See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/." + ) + raise InvalidOperationError(msg) + if arg._metadata.is_filtration: + msg = ( + "Length-changing expressions are not supported for use in LazyFrame, unless\n" + "followed by an aggregation.\n\n" + "Hints:\n" + "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" + "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" + " use `lf.select(nw.col('a').drop_nulls().sum())\n" + ) + raise InvalidOperationError(msg) + return arg raise InvalidIntoExprError.from_invalid_type(type(arg)) @property diff --git a/narwhals/group_by.py b/narwhals/group_by.py index c469ac921e..1bdd8e9ec2 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar -from narwhals._expression_parsing import all_exprs_are_scalar_like -from narwhals._utils import flatten, tupleify +from narwhals._utils import tupleify from narwhals.exceptions import InvalidOperationError from narwhals.typing import DataFrameT @@ -72,8 +71,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: 2 b 3 2 3 c 3 1 """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(x.is_scalar_like for x in kinds): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -81,14 +80,6 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]: @@ -166,8 +157,8 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: |└─────┴─────┴─────┘| └───────────────────┘ """ - flat_aggs = tuple(flatten(aggs)) - if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs): + compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs) + if not all(x.is_scalar_like for x in kinds): msg = ( "Found expression which does not aggregate.\n\n" "All expressions passed to GroupBy.agg must aggregate.\n" @@ -175,12 +166,4 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT: "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - plx = self._df.__narwhals_namespace__() - compliant_aggs = ( - *(x._to_compliant_expr(plx) for x in flat_aggs), - *( - value.alias(key)._to_compliant_expr(plx) - for key, value in named_aggs.items() - ), - ) return self._df._with_compliant(self._grouped.agg(*compliant_aggs)) diff --git a/narwhals/series.py b/narwhals/series.py index c28ff05464..f57a0f1c62 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -4,6 +4,7 @@ from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, overload +from narwhals._expression_parsing import ExprMetadata from narwhals._utils import ( Implementation, Version, @@ -20,6 +21,7 @@ from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar from narwhals.dtypes import _validate_dtype, _validate_into_dtype from narwhals.exceptions import ComputeError, InvalidOperationError +from narwhals.expr import Expr from narwhals.series_cat import SeriesCatNamespace from narwhals.series_dt import SeriesDateTimeNamespace from narwhals.series_list import SeriesListNamespace @@ -89,6 +91,10 @@ def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame + def _to_expr(self) -> Expr: + md = ExprMetadata.selector_single() + return Expr(lambda _plx: self._compliant._to_expr(), md) + def __init__( self, series: Any, *, level: Literal["full", "lazy", "interchange"] ) -> None: diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index f9780fa4a9..640b76ac4d 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -6,6 +6,7 @@ import narwhals as nw from narwhals import exceptions, functions as nw_f from narwhals._exceptions import issue_warning +from narwhals._expression_parsing import is_expr from narwhals._typing_compat import TypeVar, assert_never from narwhals._utils import ( Implementation, @@ -233,17 +234,13 @@ def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) -> def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame - def _extract_compliant(self, arg: Any) -> Any: + def _parse_into_expr(self, arg: Expr | str) -> Expr: # type: ignore[override] # After v1, we raise when passing order-dependent, length-changing, # or filtration expressions to LazyFrame - from narwhals.expr import Expr - from narwhals.series import Series - - if isinstance(arg, Series): # pragma: no cover - msg = "Mixing Series with LazyFrame is not supported." - raise TypeError(msg) - if isinstance(arg, (Expr, str)): - return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False) + if isinstance(arg, str): + return col(arg) + if is_expr(arg): + return arg raise InvalidIntoExprError.from_invalid_type(type(arg)) def collect( diff --git a/tests/frame/group_by_test.py b/tests/frame/group_by_test.py index 183811dddf..1b5a359845 100644 --- a/tests/frame/group_by_test.py +++ b/tests/frame/group_by_test.py @@ -364,7 +364,7 @@ def test_group_by_shift_raises(constructor: Constructor) -> None: df_native = {"a": [1, 2, 3], "b": [1, 1, 2]} df = nw.from_native(constructor(df_native)) with pytest.raises(InvalidOperationError, match="does not aggregate"): - df.group_by("b").agg(nw.col("a").shift(1)) + df.group_by("b").agg(nw.col("a").abs()) def test_double_same_aggregation(