diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 148259d8f0..a7a28dad7f 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -29,7 +29,7 @@ from narwhals.expr import Expr from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr - from narwhals.typing import CompliantFrameT_contra + from narwhals.typing import CompliantFrameT from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantNamespace from narwhals.typing import CompliantSeries @@ -51,8 +51,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]: def evaluate_into_expr( - df: CompliantFrameT_contra, - expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + df: CompliantFrameT, + expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> Sequence[CompliantSeriesT_co]: """Return list of raw columns. @@ -72,9 +72,9 @@ def evaluate_into_expr( def evaluate_into_exprs( - df: CompliantFrameT_contra, + df: CompliantFrameT, /, - *exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + *exprs: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> list[CompliantSeriesT_co]: """Evaluate each expr into Series.""" return [ @@ -86,8 +86,8 @@ def evaluate_into_exprs( @overload def maybe_evaluate_expr( - df: CompliantFrameT_contra, - expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], + df: CompliantFrameT, + expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], ) -> CompliantSeriesT_co: ... @@ -257,15 +257,15 @@ def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool: def combine_evaluate_output_names( - *exprs: CompliantExpr[CompliantFrameT_contra, Any], -) -> Callable[[CompliantFrameT_contra], Sequence[str]]: + *exprs: CompliantExpr[CompliantFrameT, Any], +) -> Callable[[CompliantFrameT], Sequence[str]]: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the # first name of `expr1`. if not is_compliant_expr(exprs[0]): # pragma: no cover msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug." raise AssertionError(msg) - def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]: + def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]: return exprs[0]._evaluate_output_names(df)[:1] return evaluate_output_names @@ -286,11 +286,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co], + plx: CompliantNamespace[CompliantFrameT, CompliantSeriesT_co], other: Any, *, str_as_lit: bool, -) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object: +) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: diff --git a/narwhals/_protocols.py b/narwhals/_protocols.py new file mode 100644 index 0000000000..d796b589dc --- /dev/null +++ b/narwhals/_protocols.py @@ -0,0 +1,312 @@ +"""Refinements/extensions on `Compliant*` protocols. + +Mentioned in https://github.com/narwhals-dev/narwhals/issues/2044#issuecomment-2668381264 + +Notes: +- `Reuse*` are shared `pyarrow` & `pandas` +- No strong feelings on the name, borrowed from `reuse_series_(namespace_?)implementation` +- https://github.com/narwhals-dev/narwhals/blob/70220169458599013a06fbf3024effcb860d62ed/narwhals/_expression_parsing.py#L112 +""" + +from __future__ import annotations + +from itertools import chain +from operator import methodcaller +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Protocol +from typing import Sequence +from typing import TypeVar + +from narwhals._expression_parsing import evaluate_output_names_and_aliases +from narwhals.typing import CompliantDataFrame +from narwhals.typing import CompliantExpr +from narwhals.typing import CompliantNamespace +from narwhals.typing import CompliantSeries + +if TYPE_CHECKING: + from types import ModuleType + from typing import Iterable + from typing import Iterator + + from typing_extensions import Self + from typing_extensions import TypeIs + + from narwhals.dtypes import DType + from narwhals.typing import NativeSeries + from narwhals.typing import _1DArray + from narwhals.utils import Implementation + from narwhals.utils import Version + + T = TypeVar("T") + +NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True) +ReuseSeriesT = TypeVar("ReuseSeriesT", bound="ReuseSeries[Any]") + +ReuseDataFrameT = TypeVar("ReuseDataFrameT", bound="ReuseDataFrame[Any]") + + +class ReuseDataFrame(CompliantDataFrame, Protocol[ReuseSeriesT]): + def _evaluate_into_expr( + self: ReuseDataFrameT, expr: ReuseExpr[ReuseDataFrameT, ReuseSeriesT], / + ) -> Sequence[ReuseSeriesT]: + _, aliases = evaluate_output_names_and_aliases(expr, self, []) + result = expr(self) + if list(aliases) != [s.name for s in result]: + msg = f"Safety assertion failed, expected {aliases}, got {result}" + raise AssertionError(msg) + return result + + def _evaluate_into_exprs( + self, *exprs: ReuseExpr[Self, ReuseSeriesT] + ) -> Sequence[ReuseSeriesT]: + return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) + + # NOTE: `mypy` is **very** fragile here in what is permitted + # DON'T CHANGE UNLESS IT SOLVES ANOTHER ISSUE + def _maybe_evaluate_expr( + self, expr: ReuseExpr[ReuseDataFrame[ReuseSeriesT], ReuseSeriesT] | T, / + ) -> ReuseSeriesT | T: + if is_reuse_expr(expr): + result: Sequence[ReuseSeriesT] = expr(self) + if len(result) > 1: + msg = ( + "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) " + "are not supported in this context" + ) + raise ValueError(msg) + return result[0] + return expr + + +# NOTE: DON'T CHANGE THIS EITHER +def is_reuse_expr( + obj: ReuseExpr[Any, ReuseSeriesT] | Any, +) -> TypeIs[ReuseExpr[Any, ReuseSeriesT]]: + return hasattr(obj, "__narwhals_expr__") + + +class ReuseExpr( + CompliantExpr[ReuseDataFrameT, ReuseSeriesT], + Protocol[ReuseDataFrameT, ReuseSeriesT], +): + _call: Callable[[ReuseDataFrameT], Sequence[ReuseSeriesT]] + _depth: int + _function_name: str + _evaluate_output_names: Any + _alias_output_names: Any + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + _kwargs: dict[str, Any] + + def __init__( + self: Self, + call: Callable[[ReuseDataFrameT], Sequence[ReuseSeriesT]], + *, + depth: int, + function_name: str, + evaluate_output_names: Callable[[ReuseDataFrameT], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, + implementation: Implementation, + backend_version: tuple[int, ...], + version: Version, + kwargs: dict[str, Any], + ) -> None: ... + + def __call__(self, df: ReuseDataFrameT) -> Sequence[ReuseSeriesT]: + return self._call(df) + + def __narwhals_namespace__(self) -> ReuseNamespace[ReuseDataFrameT, ReuseSeriesT]: ... + + def _reuse_series_implementation( + self: ReuseExpr[ReuseDataFrameT, ReuseSeriesT], + attr: str, + *, + returns_scalar: bool = False, + **expressifiable_args: Any, + ) -> ReuseExpr[ReuseDataFrameT, ReuseSeriesT]: + from narwhals._expression_parsing import evaluate_output_names_and_aliases + + plx = self.__narwhals_namespace__() + from_scalar = plx._create_series_from_scalar + + # NOTE: Ideally this would be implemented differently for `pandas` and `pyarrow` + # - It wouldn't make sense to check the implementation for each call + # - Just copying over from the function + def func(df: ReuseDataFrameT, /) -> Sequence[ReuseSeriesT]: + _kwargs = { + arg_name: df._maybe_evaluate_expr(arg_value) + for arg_name, arg_value in expressifiable_args.items() + } + # For PyArrow.Series, we return Python Scalars (like Polars does) instead of PyArrow Scalars. + # However, when working with expressions, we keep everything PyArrow-native. + extra_kwargs = ( + {"_return_py_scalar": False} + if returns_scalar and self._implementation.is_pyarrow() + else {} + ) + method = methodcaller(attr, **extra_kwargs, **_kwargs) + out: Sequence[ReuseSeriesT] = [ + from_scalar(method(series), reference_series=series) + if returns_scalar + else method(series) + for series in self(df) + ] + _, aliases = evaluate_output_names_and_aliases(self, df, []) + if [s.name for s in out] != list(aliases): # pragma: no cover + msg = ( + f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n" + f"Expression aliases: {aliases}\n" + f"Series names: {[s.name for s in out]}" + ) + raise AssertionError(msg) + return out + + return plx._create_expr_from_callable( + func, + depth=self._depth + 1, + function_name=f"{self._function_name}->{attr}", + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + kwargs={**self._kwargs, **expressifiable_args}, + ) + + def _reuse_series_namespace_implementation( + self: ReuseExpr[ReuseDataFrameT, ReuseSeriesT], + series_namespace: str, + attr: str, + **kwargs: Any, + ) -> ReuseExpr[ReuseDataFrameT, ReuseSeriesT]: + plx = self.__narwhals_namespace__() + return plx._create_expr_from_callable( + lambda df: [ + getattr(getattr(series, series_namespace), attr)(**kwargs) + for series in self(df) + ], + depth=self._depth + 1, + function_name=f"{self._function_name}->{series_namespace}.{attr}", + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, + kwargs={**self._kwargs, **kwargs}, + ) + + +class ReuseSeries(CompliantSeries, Protocol[NativeSeriesT_co]): # type: ignore[misc] + _name: str + _native_series: NativeSeriesT_co + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def __init__( + self, + native_series: NativeSeriesT_co, + *, + implementation: Implementation, + backend_version: tuple[int, ...], + version: Version, + ) -> None: ... + + def __array__(self: Self, dtype: Any, copy: bool | None) -> _1DArray: ... + def __contains__(self: Self, other: Any) -> bool: ... + def __iter__(self: Self) -> Iterator[Any]: ... + def __len__(self: Self) -> int: + return len(self._native_series) + + def __narwhals_series__(self: Self) -> Self: + return self + + # NOTE: Each side adds an `AssertionError` guard first + # Would probably make more sense to require a ClassVar, either: + # - defining the set of permitted impls + # - setting an unbound method, which also gives the error message + # - E.g. `PandasLikeSeries._ensure = Implementation.ensure_pandas_like` + # - That gets called w/ `ReuseSeries._ensure(ReuseSeries._implementation)` + def __native_namespace__(self: Self) -> ModuleType: + return self._implementation.to_native_namespace() + + @classmethod + def _from_iterable( + cls: type[Self], + data: Iterable[Any], + name: str, + *, + implementation: Implementation, + backend_version: tuple[int, ...], + version: Version, + ) -> Self: ... + def _from_native_series(self: Self, series: NativeSeriesT_co | Any) -> Self: ... + @property + def dtype(self: Self) -> DType: ... + @property + def name(self: Self) -> str: + return self._name + + +class ReuseNamespace( + CompliantNamespace[ReuseDataFrameT, ReuseSeriesT], + Protocol[ReuseDataFrameT, ReuseSeriesT], +): + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def __init__( + self, + implementation: Implementation, + backend_version: tuple[int, ...], + version: Version, + ) -> None: ... + + def __narwhals_expr__( + self, + ) -> Callable[..., ReuseExpr[ReuseDataFrameT, ReuseSeriesT]]: ... + # Both do very similar things: + # - `_pandas_like.utils.create_compliant_series` + # - `_arrow.series.ArrowSeries(native_series=pa.chunked_array([value]))` + def _create_compliant_series(self, value: Any) -> ReuseSeriesT: ... + + # NOTE: Fully spec'd + def _create_expr_from_callable( + self, + func: Callable[[ReuseDataFrameT], Sequence[ReuseSeriesT]], + *, + depth: int, + function_name: str, + evaluate_output_names: Callable[[ReuseDataFrameT], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, + kwargs: dict[str, Any], + ) -> ReuseExpr[ReuseDataFrameT, ReuseSeriesT]: + return self.__narwhals_expr__()( + func, + depth=depth, + function_name=function_name, + evaluate_output_names=evaluate_output_names, + alias_output_names=alias_output_names, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, + kwargs=kwargs, + ) + + # NOTE: Fully spec'd + def _create_expr_from_series( + self, series: ReuseSeriesT + ) -> ReuseExpr[ReuseDataFrameT, ReuseSeriesT]: + return self.__narwhals_expr__()( + lambda _df: [series], + depth=0, + function_name="series", + evaluate_output_names=lambda _df: [series.name], + alias_output_names=None, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, + kwargs={}, + ) + + def _create_series_from_scalar( + self, value: Any, *, reference_series: ReuseSeriesT + ) -> ReuseSeriesT: ... diff --git a/narwhals/typing.py b/narwhals/typing.py index f1b62179e6..3d1a55aa07 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Generic from typing import Literal from typing import Protocol from typing import Sequence @@ -74,30 +73,28 @@ def aggregate(self, *exprs: Any) -> Self: # (so, no broadcasting is necessary). -CompliantFrameT_contra = TypeVar( - "CompliantFrameT_contra", - bound="CompliantDataFrame | CompliantLazyFrame", - contravariant=True, +CompliantFrameT = TypeVar( + "CompliantFrameT", bound="CompliantDataFrame | CompliantLazyFrame" ) CompliantSeriesT_co = TypeVar( "CompliantSeriesT_co", bound=CompliantSeries, covariant=True ) -class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): +class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]): _implementation: Implementation _backend_version: tuple[int, ...] _version: Version - _evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]] + _evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]] _alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None _depth: int _function_name: str - def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ... + def __call__(self, df: CompliantFrameT) -> Sequence[CompliantSeriesT_co]: ... def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, - ) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantNamespace[CompliantFrameT, CompliantSeriesT_co]: ... def is_null(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: DType) -> Self: ... @@ -119,13 +116,13 @@ def broadcast( ) -> Self: ... -class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): +class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesT_co]): def col( self, *column_names: str - ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ... def lit( self, value: Any, dtype: DType | None - ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... + ) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ... class SupportsNativeNamespace(Protocol): @@ -325,7 +322,7 @@ class DTypes: # This one needs to be in TYPE_CHECKING to pass on 3.9, # and can only be defined after CompliantExpr has been defined IntoCompliantExpr: TypeAlias = ( - CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co + CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | CompliantSeriesT_co ) diff --git a/narwhals/utils.py b/narwhals/utils.py index 0babd5440a..f628ed6f03 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -53,7 +53,7 @@ from narwhals.series import Series from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr - from narwhals.typing import CompliantFrameT_contra + from narwhals.typing import CompliantFrameT from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantSeries from narwhals.typing import CompliantSeriesT_co @@ -411,6 +411,20 @@ def is_ibis(self: Self) -> bool: """ return self is Implementation.IBIS # pragma: no cover + def ensure_pandas_like(self: Self) -> None: + if self.is_pandas_like(): + return + pandas_like = {Implementation.PANDAS, Implementation.CUDF, Implementation.MODIN} + msg = f"Expected pandas-like implementation ({pandas_like}), found {self}" + raise TypeError(msg) + + # NOTE: Not sure why this differs from `pandas_like` + def ensure_pyarrow(self: Self) -> None: + if self.is_pyarrow(): + return + msg = f"Expected pyarrow, got: {type(self)}" # pragma: no cover + raise AssertionError(msg) + MIN_VERSIONS: dict[Implementation, tuple[int, ...]] = { Implementation.PANDAS: (0, 25, 3), @@ -1321,8 +1335,8 @@ def is_compliant_series(obj: Any) -> TypeIs[CompliantSeries]: def is_compliant_expr( - obj: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | Any, -) -> TypeIs[CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]]: + obj: CompliantExpr[CompliantFrameT, CompliantSeriesT_co] | Any, +) -> TypeIs[CompliantExpr[CompliantFrameT, CompliantSeriesT_co]]: return hasattr(obj, "__narwhals_expr__") diff --git a/pyproject.toml b/pyproject.toml index dd4244f426..23418264bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -247,6 +247,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ "*._expression_parsing", + "*._protocols", "*.utils", "*._pandas_like.*", "*._ibis.*",