diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py new file mode 100644 index 0000000000..867d16d397 --- /dev/null +++ b/narwhals/_plan/_guards.py @@ -0,0 +1,103 @@ +"""Common type guards, mostly with inline imports.""" + +from __future__ import annotations + +import datetime as dt +from decimal import Decimal +from typing import TYPE_CHECKING, Any, TypeVar + +from narwhals._utils import _hasattr_static + +if TYPE_CHECKING: + from typing_extensions import TypeIs + + from narwhals._plan import expr + from narwhals._plan.dummy import Expr, Series + from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.typing import NativeSeriesT, Seq + from narwhals.typing import NonNestedLiteral + + T = TypeVar("T") + +_NON_NESTED_LITERAL_TPS = ( + int, + float, + str, + dt.date, + dt.time, + dt.timedelta, + bytes, + Decimal, +) + + +def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import dummy + + return dummy + + +def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202 + from narwhals._plan import expr + + return expr + + +def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: + return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) + + +def is_expr(obj: Any) -> TypeIs[Expr]: + return isinstance(obj, _dummy().Expr) + + +def is_column(obj: Any) -> TypeIs[Expr]: + """Indicate if the given object is a basic/unaliased column.""" + return is_expr(obj) and obj.meta.is_column() + + +def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: + return isinstance(obj, _dummy().Series) + + +def is_compliant_series( + obj: CompliantSeries[NativeSeriesT] | Any, +) -> TypeIs[CompliantSeries[NativeSeriesT]]: + return _hasattr_static(obj, "__narwhals_series__") + + +def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: + return isinstance(obj, (str, bytes, _dummy().Series)) or is_compliant_series(obj) + + +def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]: + return isinstance(obj, _expr().WindowExpr) + + +def is_function_expr(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: + return isinstance(obj, _expr().FunctionExpr) + + +def is_binary_expr(obj: Any) -> TypeIs[expr.BinaryExpr]: + return isinstance(obj, _expr().BinaryExpr) + + +def is_agg_expr(obj: Any) -> TypeIs[expr.AggExpr]: + return isinstance(obj, _expr().AggExpr) + + +def is_aggregation(obj: Any) -> TypeIs[expr.AggExpr | expr.FunctionExpr[Any]]: + """Superset of `ExprIR.is_scalar`, excludes literals & len.""" + return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) + + +def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: + return isinstance(obj, _expr().Literal) + + +def is_horizontal_reduction(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]: + return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() + + +def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: + return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py new file mode 100644 index 0000000000..0abe0739b6 --- /dev/null +++ b/narwhals/_plan/_immutable.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any, Callable + + from typing_extensions import Never, Self, dataclass_transform + +else: + # https://docs.python.org/3/library/typing.html#typing.dataclass_transform + def 
dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), + **kwargs: Any, + ) -> Callable[[T], T]: + def decorator(cls_or_fn: T) -> T: + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + + return decorator + + +T = TypeVar("T") +_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" + + +@dataclass_transform(kw_only_default=True, frozen_default=True) +class Immutable: + """A poor man's frozen dataclass. + + - Keyword-only constructor (IDE supported) + - Manual `__slots__` required + - Compatible with [`copy.replace`] + - No handling for default arguments + + [`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace + """ + + __slots__ = (_IMMUTABLE_HASH_NAME,) + __immutable_hash_value__: int + + @property + def __immutable_keys__(self) -> Iterator[str]: + slots: tuple[str, ...] = self.__slots__ + for name in slots: + if name != _IMMUTABLE_HASH_NAME: + yield name + + @property + def __immutable_values__(self) -> Iterator[Any]: + for name in self.__immutable_keys__: + yield getattr(self, name) + + @property + def __immutable_items__(self) -> Iterator[tuple[str, Any]]: + for name in self.__immutable_keys__: + yield name, getattr(self, name) + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *self.__immutable_values__)) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ + + def __setattr__(self, name: str, value: Never) -> Never: + msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." + raise AttributeError(msg) + + def __replace__(self, **changes: Any) -> Self: + """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 + if len(changes) == 1: + # The most common case is a single field replacement. + # Iff that field happens to be equal, we can noop, preserving the current object's hash. + name, value_changed = next(iter(changes.items())) + if getattr(self, name) == value_changed: + return self + changes = dict(self.__immutable_items__, **changes) + else: + for name, value_current in self.__immutable_items__: + if name not in changes or value_current == changes[name]: + changes[name] = value_current + return type(self)(**changes) + + def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + if cls.__slots__: + ... + else: + cls.__slots__ = () + + def __hash__(self) -> int: + return self.__immutable_hash__ + + def __eq__(self, other: object) -> bool: + if self is other: + return True + if type(self) is not type(other): + return False + return all( + getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ + ) + + def __str__(self) -> str: + fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) + return f"{type(self).__name__}({fields})" + + def __init__(self, **kwds: Any) -> None: + required: set[str] = set(self.__immutable_keys__) + if not required and not kwds: + # NOTE: Fastpath for empty slots + ... 
+ elif required == set(kwds): + for name, value in kwds.items(): + object.__setattr__(self, name, value) + elif missing := required.difference(kwds): + msg = ( + f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" + f"but missing values for {sorted(missing)!r}" + ) + raise TypeError(msg) + else: + extra = set(kwds).difference(required) + msg = ( + f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" + f"but got unknown arguments {sorted(extra)!r}" + ) + raise TypeError(msg) + + +def _field_str(name: str, value: Any) -> str: + if isinstance(value, tuple): + inner = ", ".join(f"{v}" for v in value) + return f"{name}=[{inner}]" + if isinstance(value, str): + return f"{name}={value!r}" + return f"{name}={value}" diff --git a/narwhals/_plan/aggregation.py b/narwhals/_plan/aggregation.py index ea25e82ad1..b1f47ca1d7 100644 --- a/narwhals/_plan/aggregation.py +++ b/narwhals/_plan/aggregation.py @@ -2,19 +2,16 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import ExprIR, _pascal_to_snake_case, replace +from narwhals._plan.common import ExprIR, _pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: from collections.abc import Iterator - from typing_extensions import Self - - from narwhals._plan.typing import MapIR from narwhals.typing import RollingInterpolationMethod -class AggExpr(ExprIR): +class AggExpr(ExprIR, child=("expr",)): __slots__ = ("expr",) expr: ExprIR @@ -25,50 +22,31 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return replace(self, expr=expr) - def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: raise agg_scalar_error(self, expr) super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] +# fmt: off class Count(AggExpr): ... - - class Max(AggExpr): ... - - class Mean(AggExpr): ... - - class Median(AggExpr): ... - - class Min(AggExpr): ... - - class NUnique(AggExpr): ... - - +class Sum(AggExpr): ... +class OrderableAggExpr(AggExpr): ... +class First(OrderableAggExpr): ... +class Last(OrderableAggExpr): ... +class ArgMin(OrderableAggExpr): ... +class ArgMax(OrderableAggExpr): ... +# fmt: on class Quantile(AggExpr): __slots__ = (*AggExpr.__slots__, "interpolation", "quantile") - quantile: float interpolation: RollingInterpolationMethod @@ -78,24 +56,6 @@ class Std(AggExpr): ddof: int -class Sum(AggExpr): ... - - class Var(AggExpr): __slots__ = (*AggExpr.__slots__, "ddof") ddof: int - - -class OrderableAggExpr(AggExpr): ... - - -class First(OrderableAggExpr): ... - - -class Last(OrderableAggExpr): ... - - -class ArgMin(OrderableAggExpr): ... - - -class ArgMax(OrderableAggExpr): ... 
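Taken together, the modules above give IR nodes value semantics: `Immutable` supplies the keyword-only constructor, cached hash, field-wise equality and `copy.replace` support, and `AggExpr` (declared with `child=("expr",)`) is one of its concrete users. A minimal usage sketch, assuming the patch is applied; `Point` is an illustrative class, and the `Column(name=...)` signature is inferred from the keyword-only constructor rather than confirmed by the diff:

    import copy
    import sys

    from narwhals._plan._immutable import Immutable
    from narwhals._plan.aggregation import Min
    from narwhals._plan.expr import Column


    class Point(Immutable):
        # Manual __slots__ are required; the annotations are only for IDE support.
        __slots__ = ("x", "y")
        x: int
        y: int


    p = Point(x=1, y=2)                      # keyword-only constructor
    assert p == Point(x=1, y=2)              # field-wise equality
    assert hash(p) == hash(Point(x=1, y=2))  # hash((cls, *values)), cached per instance
    assert p.__replace__(y=2) is p           # single equal-field replace is a no-op
    assert p.__replace__(y=3) == Point(x=1, y=3)
    if sys.version_info >= (3, 13):
        assert copy.replace(p, y=3) == Point(x=1, y=3)  # __replace__ backs copy.replace

    # The same machinery drives the IR nodes, e.g. the aggregation classes above
    # (AggExpr.__init__ additionally rejects inputs that are already scalar):
    ir = Min(expr=Column(name="a"))
    assert ir == Min(expr=Column(name="a"))  # nodes compare and hash like values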
diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index cbc5f600b4..fc61e69acc 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -1,34 +1,34 @@ from __future__ import annotations -import typing as t +from typing import TYPE_CHECKING, Any, Literal, overload import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.common import ExprIR -from narwhals._plan.protocols import EagerDataFrame +from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._utils import Version +from narwhals.schema import Schema -if t.TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence from typing_extensions import Self from narwhals._arrow.typing import ChunkedArrayAny - from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar + from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar from narwhals._plan.arrow.namespace import ArrowNamespace from narwhals._plan.common import ExprIR, NamedIR - from narwhals._plan.dummy import DataFrame + from narwhals._plan.dummy import DataFrame as NwDataFrame from narwhals._plan.options import SortMultipleOptions from narwhals._plan.typing import Seq from narwhals.dtypes import DType - from narwhals.schema import Schema + from narwhals.typing import IntoSchema -class ArrowDataFrame(EagerDataFrame[ArrowSeries, "pa.Table", "ChunkedArrayAny"]): +class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]): def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._plan.arrow.namespace import ArrowNamespace @@ -49,47 +49,37 @@ def schema(self) -> dict[str, DType]: def __len__(self) -> int: return self.native.num_rows - def to_narwhals(self) -> DataFrame[pa.Table, ChunkedArrayAny]: + def to_narwhals(self) -> NwDataFrame[pa.Table, ChunkedArrayAny]: from narwhals._plan.dummy import DataFrame return DataFrame[pa.Table, "ChunkedArrayAny"]._from_compliant(self) @classmethod def from_dict( - cls, - data: t.Mapping[str, t.Any], - /, - *, - schema: t.Mapping[str, DType] | Schema | None = None, + cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: - from narwhals.schema import Schema - pa_schema = Schema(schema).to_arrow() if schema is not None else schema native = pa.Table.from_pydict(data, schema=pa_schema) return cls.from_native(native, version=Version.MAIN) - def iter_columns(self) -> t.Iterator[ArrowSeries]: + def iter_columns(self) -> Iterator[Series]: for name, series in zip(self.columns, self.native.itercolumns()): - yield ArrowSeries.from_native(series, name, version=self.version) - - @t.overload - def to_dict(self, *, as_series: t.Literal[True]) -> dict[str, ArrowSeries]: ... - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - @t.overload - def to_dict( - self, *, as_series: bool - ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: ... - def to_dict( - self, *, as_series: bool - ) -> dict[str, ArrowSeries] | dict[str, list[t.Any]]: + yield Series.from_native(series, name, version=self.version) + + @overload + def to_dict(self, *, as_series: Literal[True]) -> dict[str, Series]: ... 
+ @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload + def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: ... + def to_dict(self, *, as_series: bool) -> dict[str, Series] | dict[str, list[Any]]: it = self.iter_columns() if as_series: return {ser.name: ser for ser in it} return {ser.name: ser.to_list() for ser in it} - def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]: - ns = self.__narwhals_namespace__() + def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[Series]: + ns = namespace(self) from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) @@ -101,16 +91,16 @@ def sort(self, by: Seq[NamedIR], options: SortMultipleOptions) -> Self: def with_row_index(self, name: str) -> Self: return self._with_native(self.native.add_column(0, name, fn.int_range(len(self)))) - def get_column(self, name: str) -> ArrowSeries: + def get_column(self, name: str) -> Series: chunked = self.native.column(name) - return ArrowSeries.from_native(chunked, name, version=self.version) + return Series.from_native(chunked, name, version=self.version) def drop(self, columns: Sequence[str]) -> Self: to_drop = list(columns) return self._with_native(self.native.drop(to_drop)) # NOTE: Use instead of `with_columns` for trivial cases - def _with_columns(self, exprs: Iterable[ArrowExpr | ArrowScalar], /) -> Self: + def _with_columns(self, exprs: Iterable[Expr | Scalar], /) -> Self: native = self.native columns = self.columns height = len(self) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index ef6e0164ed..d8a163d120 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -7,11 +7,10 @@ from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.functions import lit -from narwhals._plan.arrow.series import ArrowSeries -from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co -from narwhals._plan.common import ExprIR, NamedIR, into_dtype -from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch +from narwhals._plan.arrow.series import ArrowSeries as Series +from narwhals._plan.arrow.typing import ChunkedOrScalarAny, NativeScalar, StoresNativeT_co +from narwhals._plan.common import ExprIR, NamedIR +from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, Version, @@ -24,7 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from narwhals._arrow.typing import ChunkedArrayAny, Incomplete from narwhals._plan import boolean, expr @@ -44,28 +43,29 @@ Sum, Var, ) - from narwhals._plan.arrow.dataframe import ArrowDataFrame + from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame from narwhals._plan.arrow.namespace import ArrowNamespace - from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull + from narwhals._plan.boolean import All, IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.expr import ( AnonymousExpr, BinaryExpr, FunctionExpr, OrderedWindowExpr, RollingExpr, - Ternary, + TernaryExpr, WindowExpr, ) from narwhals._plan.functions import FillNull, Pow from narwhals.typing import Into1DArray, IntoDType, PythonLiteral + Expr: TypeAlias = "ArrowExpr" + Scalar: TypeAlias = "ArrowScalar" + 
BACKEND_VERSION = Implementation.PYARROW._backend_version() -class _ArrowDispatch( - ExprDispatch["ArrowDataFrame", StoresNativeT_co, "ArrowNamespace"], Protocol -): +class _ArrowDispatch(ExprDispatch["Frame", StoresNativeT_co, "ArrowNamespace"], Protocol): """Common to `Expr`, `Scalar` + their dependencies.""" def __narwhals_namespace__(self) -> ArrowNamespace: @@ -74,94 +74,84 @@ def __narwhals_namespace__(self) -> ArrowNamespace: return ArrowNamespace(self.version) def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... - def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> StoresNativeT_co: + def cast(self, node: expr.Cast, frame: Frame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._dispatch(node.expr, frame, name).native + native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) - def pow( - self, node: FunctionExpr[Pow], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def pow(self, node: FunctionExpr[Pow], frame: Frame, name: str) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) - base_ = self._dispatch(base, frame, "base").native - exponent_ = self._dispatch(exponent, frame, "exponent").native + base_ = base.dispatch(self, frame, "base").native + exponent_ = exponent.dispatch(self, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) def fill_null( - self, node: FunctionExpr[FillNull], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[FillNull], frame: Frame, name: str ) -> StoresNativeT_co: expr, value = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - value_ = self._dispatch(value, frame, "value").native + native = expr.dispatch(self, frame, name).native + value_ = value.dispatch(self, frame, "value").native return self._with_native(pc.fill_null(native, value_), name) def is_between( - self, node: FunctionExpr[IsBetween], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsBetween], frame: Frame, name: str ) -> StoresNativeT_co: expr, lower_bound, upper_bound = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - lower = self._dispatch(lower_bound, frame, "lower").native - upper = self._dispatch(upper_bound, frame, "upper").native + native = expr.dispatch(self, frame, name).native + lower = lower_bound.dispatch(self, frame, "lower").native + upper = upper_bound.dispatch(self, frame, "upper").native result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) def _unary_function( self, fn_native: Callable[[Any], Any], / - ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], StoresNativeT_co]: - def func( - node: FunctionExpr[Any], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: - native = self._dispatch(node.input[0], frame, name).native + ) -> Callable[[FunctionExpr[Any], Frame, str], StoresNativeT_co]: + def func(node: FunctionExpr[Any], frame: Frame, name: str) -> StoresNativeT_co: + native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn_native(native), name) return func - def not_( - self, node: FunctionExpr[boolean.Not], frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def not_(self, node: FunctionExpr[Not], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(pc.invert)(node, frame, name) - def all( - self, node: FunctionExpr[boolean.All], frame: 
ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def all(self, node: FunctionExpr[All], frame: Frame, name: str) -> StoresNativeT_co: return self._unary_function(fn.all_)(node, frame, name) def any( - self, node: FunctionExpr[boolean.Any], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[boolean.Any], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.any_)(node, frame, name) def is_finite( - self, node: FunctionExpr[IsFinite], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsFinite], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_finite)(node, frame, name) def is_nan( - self, node: FunctionExpr[IsNan], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsNan], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_nan)(node, frame, name) def is_null( - self, node: FunctionExpr[IsNull], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[IsNull], frame: Frame, name: str ) -> StoresNativeT_co: return self._unary_function(fn.is_null)(node, frame, name) - def binary_expr( - self, node: BinaryExpr, frame: ArrowDataFrame, name: str - ) -> StoresNativeT_co: + def binary_expr(self, node: BinaryExpr, frame: Frame, name: str) -> StoresNativeT_co: lhs, rhs = ( - self._dispatch(node.left, frame, name), - self._dispatch(node.right, frame, name), + node.left.dispatch(self, frame, name), + node.right.dispatch(self, frame, name), ) result = fn.binary(lhs.native, node.op.__class__, rhs.native) return self._with_native(result, name) def ternary_expr( - self, node: Ternary, frame: ArrowDataFrame, name: str + self, node: TernaryExpr, frame: Frame, name: str ) -> StoresNativeT_co: - when = self._dispatch(node.predicate, frame, name) - then = self._dispatch(node.truthy, frame, name) - otherwise = self._dispatch(node.falsy, frame, name) + when = node.predicate.dispatch(self, frame, name) + then = node.truthy.dispatch(self, frame, name) + otherwise = node.falsy.dispatch(self, frame, name) result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) @@ -169,9 +159,9 @@ def ternary_expr( class ArrowExpr( # type: ignore[misc] _ArrowDispatch["ArrowExpr | ArrowScalar"], _StoresNative["ChunkedArrayAny"], - EagerExpr["ArrowDataFrame", ArrowSeries], + EagerExpr["Frame", Series], ): - _evaluated: ArrowSeries + _evaluated: Series _version: Version @property @@ -179,7 +169,7 @@ def name(self) -> str: return self._evaluated.name @classmethod - def from_series(cls, series: ArrowSeries, /) -> Self: + def from_series(cls, series: Series, /) -> Self: obj = cls.__new__(cls) obj._evaluated = series obj._version = series.version @@ -189,26 +179,20 @@ def from_series(cls, series: ArrowSeries, /) -> Self: def from_native( cls, native: ChunkedArrayAny, name: str = "", /, version: Version = Version.MAIN ) -> Self: - return cls.from_series(ArrowSeries.from_native(native, name, version=version)) + return cls.from_series(Series.from_native(native, name, version=version)) @overload def _with_native(self, result: ChunkedArrayAny, name: str, /) -> Self: ... @overload - def _with_native(self, result: NativeScalar, name: str, /) -> ArrowScalar: ... + def _with_native(self, result: NativeScalar, name: str, /) -> Scalar: ... @overload - def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str, / - ) -> ArrowScalar | Self: ... 
- def _with_native( - self, result: ChunkedArrayAny | NativeScalar, name: str, / - ) -> ArrowScalar | Self: + def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: ... + def _with_native(self, result: ChunkedOrScalarAny, name: str, /) -> Scalar | Self: if isinstance(result, pa.Scalar): return ArrowScalar.from_native(result, name, version=self.version) return self.from_native(result, name or self.name, self.version) - def _dispatch_expr( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowSeries: + def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: """Use instead of `_dispatch` *iff* an operation isn't natively supported on `ChunkedArray`. There is no need to broadcast, as they may have a cheaper impl elsewhere (`CompliantScalar` or `ArrowScalar`). @@ -216,16 +200,16 @@ def _dispatch_expr( Mainly for the benefit of a type checker, but the equivalent `ArrowScalar._dispatch_expr` will raise if the assumption fails. """ - return self._dispatch(node, frame, name).to_series() + return node.dispatch(self, frame, name).to_series() @property def native(self) -> ChunkedArrayAny: return self._evaluated.native - def to_series(self) -> ArrowSeries: + def to_series(self) -> Series: return self._evaluated - def broadcast(self, length: int, /) -> ArrowSeries: + def broadcast(self, length: int, /) -> Series: if (actual_len := len(self)) != length: msg = f"Expected object of length {length}, got {actual_len}." raise ShapeError(msg) @@ -234,25 +218,24 @@ def broadcast(self, length: int, /) -> ArrowSeries: def __len__(self) -> int: return len(self._evaluated) - def sort(self, node: expr.Sort, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def sort(self, node: expr.Sort, frame: Frame, name: str) -> Expr: native = self._dispatch_expr(node.expr, frame, name).native sorted_indices = pc.array_sort_indices(native, options=node.options.to_arrow()) return self._with_native(native.take(sorted_indices), name) - def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def sort_by(self, node: expr.SortBy, frame: Frame, name: str) -> Expr: series = self._dispatch_expr(node.expr, frame, name) by = ( self._dispatch_expr(e, frame, f"_{idx}") for idx, e in enumerate(node.by) ) - ns = self.__narwhals_namespace__() - df = ns._concat_horizontal((series, *by)) + df = namespace(self)._concat_horizontal((series, *by)) names = df.columns[1:] indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) result: ChunkedArrayAny = df.native.column(0).take(indices) return self._with_native(result, name) - def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def filter(self, node: expr.Filter, frame: Frame, name: str) -> Expr: return self._with_native( self._dispatch_expr(node.expr, frame, name).native.filter( self._dispatch_expr(node.by, frame, name).native @@ -260,49 +243,49 @@ def filter(self, node: expr.Filter, frame: ArrowDataFrame, name: str) -> ArrowEx name, ) - def first(self, node: First, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def first(self, node: First, frame: Frame, name: str) -> Scalar: prev = self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[0] if len(prev) else lit(None, native.type) + result = native[0] if len(prev) else fn.lit(None, native.type) return self._with_native(result, name) - def last(self, node: Last, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def last(self, node: Last, frame: Frame, name: str) -> Scalar: prev = 
self._dispatch_expr(node.expr, frame, name) native = prev.native - result = native[height - 1] if (height := len(prev)) else lit(None, native.type) + result = native[len_ - 1] if (len_ := len(prev)) else fn.lit(None, native.type) return self._with_native(result, name) - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native result = pc.index(native, fn.min_(native)) return self._with_native(result, name) - def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar: native = self._dispatch_expr(node.expr, frame, name).native result: NativeScalar = pc.index(native, fn.max_(native)) return self._with_native(result, name) - def sum(self, node: Sum, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def sum(self, node: Sum, frame: Frame, name: str) -> Scalar: result = fn.sum_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: result = fn.n_unique(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def std(self, node: Std, frame: Frame, name: str) -> Scalar: result = fn.std( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) - def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def var(self, node: Var, frame: Frame, name: str) -> Scalar: result = fn.var( self._dispatch_expr(node.expr, frame, name).native, ddof=node.ddof ) return self._with_native(result, name) - def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def quantile(self, node: Quantile, frame: Frame, name: str) -> Scalar: result = fn.quantile( self._dispatch_expr(node.expr, frame, name).native, q=node.quantile, @@ -310,23 +293,23 @@ def quantile(self, node: Quantile, frame: ArrowDataFrame, name: str) -> ArrowSca )[0] return self._with_native(result, name) - def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def count(self, node: Count, frame: Frame, name: str) -> Scalar: result = fn.count(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def max(self, node: Max, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def max(self, node: Max, frame: Frame, name: str) -> Scalar: result: NativeScalar = fn.max_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def mean(self, node: Mean, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def mean(self, node: Mean, frame: Frame, name: str) -> Scalar: result = fn.mean(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def median(self, node: Median, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def median(self, node: Median, frame: Frame, name: str) -> Scalar: result = fn.median(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) - def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def min(self, node: Min, frame: Frame, name: str) -> Scalar: result: NativeScalar = 
fn.min_(self._dispatch_expr(node.expr, frame, name).native) return self._with_native(result, name) @@ -336,12 +319,12 @@ def min(self, node: Min, frame: ArrowDataFrame, name: str) -> ArrowScalar: # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main # - [ ] `rolling_expr` has 4 variants - def over(self, node: WindowExpr, frame: ArrowDataFrame, name: str) -> Self: + def over(self, node: WindowExpr, frame: Frame, name: str) -> Self: raise NotImplementedError def over_ordered( - self, node: OrderedWindowExpr, frame: ArrowDataFrame, name: str - ) -> Self | ArrowScalar: + self, node: OrderedWindowExpr, frame: Frame, name: str + ) -> Self | Scalar: if node.partition_by: msg = f"Need to implement `group_by`, `join` for:\n{node!r}" raise NotImplementedError(msg) @@ -351,7 +334,7 @@ def over_ordered( options = node.sort_options.to_multiple(len(node.order_by)) idx_name = generate_temporary_column_name(8, frame.columns) sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) - evaluated = self._dispatch(node.expr, sorted_context.drop([idx_name]), name) + evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) if isinstance(evaluated, ArrowScalar): # NOTE: We're already sorted, defer broadcasting to the outer context # Wouldn't be suitable for partitions, but will be fine here @@ -364,28 +347,28 @@ def over_ordered( return self._with_native(result, name) # NOTE: Can't implement in `EagerExpr`, since it doesn't derive `ExprDispatch` - def map_batches(self, node: AnonymousExpr, frame: ArrowDataFrame, name: str) -> Self: + def map_batches(self, node: AnonymousExpr, frame: Frame, name: str) -> Self: if node.is_scalar: - # NOTE: Just trying to avoid redoing the whole API for `ArrowSeries` + # NOTE: Just trying to avoid redoing the whole API for `Series` msg = "Only elementwise is currently supported" raise NotImplementedError(msg) series = self._dispatch_expr(node.input[0], frame, name) udf = node.function.function - result: ArrowSeries | Into1DArray = udf(series) + result: Series | Into1DArray = udf(series) if not fn.is_series(result): - result = ArrowSeries.from_numpy(result, name, version=self.version) + result = Series.from_numpy(result, name, version=self.version) if dtype := node.function.return_dtype: result = result.cast(dtype) return self.from_series(result) - def rolling_expr(self, node: RollingExpr, frame: ArrowDataFrame, name: str) -> Self: + def rolling_expr(self, node: RollingExpr, frame: Frame, name: str) -> Self: raise NotImplementedError class ArrowScalar( _ArrowDispatch["ArrowScalar"], _StoresNative[NativeScalar], - EagerScalar["ArrowDataFrame", ArrowSeries], + EagerScalar["Frame", Series], ): _evaluated: NativeScalar _version: Version @@ -416,14 +399,12 @@ def from_python( version: Version = Version.MAIN, ) -> Self: dtype_pa: pa.DataType | None = None - if dtype: - dtype = into_dtype(dtype) - if not isinstance(dtype, version.dtypes.Unknown): - dtype_pa = narwhals_to_native_dtype(dtype, version) - return cls.from_native(lit(value, dtype_pa), name, version) + if dtype and dtype != version.dtypes.Unknown: + dtype_pa = narwhals_to_native_dtype(dtype, version) + return cls.from_native(fn.lit(value, dtype_pa), name, version) @classmethod - def from_series(cls, series: ArrowSeries) -> Self: + def from_series(cls, series: Series) -> Self: if len(series) == 1: return cls.from_native(series.native[0], series.name, series.version) if len(series) == 0: @@ -433,9 +414,7 @@ def from_series(cls, series: ArrowSeries) -> Self: msg = f"Too 
long {len(series)!r}" raise InvalidOperationError(msg) - def _dispatch_expr( - self, node: ExprIR, frame: ArrowDataFrame, name: str - ) -> ArrowSeries: + def _dispatch_expr(self, node: ExprIR, frame: Frame, name: str) -> Series: msg = f"Expected unreachable, but hit at: {node!r}" raise InvalidOperationError(msg) @@ -446,13 +425,13 @@ def _with_native(self, native: Any, name: str, /) -> Self: def native(self) -> NativeScalar: return self._evaluated - def to_series(self) -> ArrowSeries: + def to_series(self) -> Series: return self.broadcast(1) def to_python(self) -> PythonLiteral: return self.native.as_py() # type: ignore[no-any-return] - def broadcast(self, length: int) -> ArrowSeries: + def broadcast(self, length: int) -> Series: scalar = self.native if length == 1: chunked = fn.chunked_array(scalar) @@ -461,25 +440,25 @@ def broadcast(self, length: int) -> ArrowSeries: # https://github.com/zen-xu/pyarrow-stubs/pull/209 pa_repeat: Incomplete = pa.repeat chunked = fn.chunked_array(pa_repeat(scalar, length)) - return ArrowSeries.from_native(chunked, self.name, version=self.version) + return Series.from_native(chunked, self.name, version=self.version) - def arg_min(self, node: ArgMin, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_min(self, node: ArgMin, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(0), name) - def arg_max(self, node: ArgMax, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def arg_max(self, node: ArgMax, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(0), name) - def n_unique(self, node: NUnique, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def n_unique(self, node: NUnique, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(1), name) - def std(self, node: Std, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def std(self, node: Std, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(None, pa.null()), name) - def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: + def var(self, node: Var, frame: Frame, name: str) -> Scalar: return self._with_native(pa.scalar(None, pa.null()), name) - def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - native = self._dispatch(node.expr, frame, name).native + def count(self, node: Count, frame: Frame, name: str) -> Scalar: + native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) filter = not_implemented() diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e941727d6b..f7bfaaa330 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -7,9 +7,9 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._arrow.utils import narwhals_to_native_dtype +from narwhals._plan._guards import is_tuple_of from narwhals._plan.arrow import functions as fn -from narwhals._plan.arrow.functions import lit -from narwhals._plan.common import collect, is_tuple_of +from narwhals._plan.common import collect from narwhals._plan.literal import is_literal_scalar from narwhals._plan.protocols import EagerNamespace from narwhals._utils import Version @@ -20,79 +20,64 @@ from narwhals._arrow.typing import ChunkedArrayAny from narwhals._plan import expr, functions as F - from narwhals._plan.arrow.dataframe import ArrowDataFrame - from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar - from narwhals._plan.arrow.series import ArrowSeries + from 
narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame + from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar + from narwhals._plan.arrow.series import ArrowSeries as Series from narwhals._plan.boolean import AllHorizontal, AnyHorizontal - from narwhals._plan.dummy import Series + from narwhals._plan.dummy import Series as NwSeries from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatHorizontal + from narwhals._plan.strings import ConcatStr from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral -class ArrowNamespace( - EagerNamespace["ArrowDataFrame", "ArrowSeries", "ArrowExpr", "ArrowScalar"] -): +class ArrowNamespace(EagerNamespace["Frame", "Series", "Expr", "Scalar"]): def __init__(self, version: Version = Version.MAIN) -> None: self._version = version @property - def _expr(self) -> type[ArrowExpr]: + def _expr(self) -> type[Expr]: from narwhals._plan.arrow.expr import ArrowExpr return ArrowExpr @property - def _scalar(self) -> type[ArrowScalar]: + def _scalar(self) -> type[Scalar]: from narwhals._plan.arrow.expr import ArrowScalar return ArrowScalar @property - def _series(self) -> type[ArrowSeries]: + def _series(self) -> type[Series]: from narwhals._plan.arrow.series import ArrowSeries return ArrowSeries @property - def _dataframe(self) -> type[ArrowDataFrame]: + def _dataframe(self) -> type[Frame]: from narwhals._plan.arrow.dataframe import ArrowDataFrame return ArrowDataFrame - def col(self, node: expr.Column, frame: ArrowDataFrame, name: str) -> ArrowExpr: + def col(self, node: expr.Column, frame: Frame, name: str) -> Expr: return self._expr.from_native( frame.native.column(node.name), name, version=frame.version ) @overload def lit( - self, node: expr.Literal[NonNestedLiteral], frame: ArrowDataFrame, name: str - ) -> ArrowScalar: ... - + self, node: expr.Literal[NonNestedLiteral], frame: Frame, name: str + ) -> Scalar: ... @overload def lit( - self, - node: expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, - name: str, - ) -> ArrowExpr: ... - - @overload - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, - name: str, - ) -> ArrowExpr | ArrowScalar: ... - + self, node: expr.Literal[NwSeries[ChunkedArrayAny]], frame: Frame, name: str + ) -> Expr: ... 
def lit( self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[ChunkedArrayAny]], - frame: ArrowDataFrame, + node: expr.Literal[NonNestedLiteral] | expr.Literal[NwSeries[ChunkedArrayAny]], + frame: Frame, name: str, - ) -> ArrowExpr | ArrowScalar: + ) -> Expr | Scalar: if is_literal_scalar(node): return self._scalar.from_python( node.unwrap(), name, dtype=node.dtype, version=frame.version @@ -106,13 +91,11 @@ def lit( # https://github.com/narwhals-dev/narwhals/pull/2719 def _horizontal_function( self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None - ) -> Callable[[FunctionExpr[Any], ArrowDataFrame, str], ArrowExpr | ArrowScalar]: - def func( - node: FunctionExpr[Any], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + ) -> Callable[[FunctionExpr[Any], Frame, str], Expr | Scalar]: + def func(node: FunctionExpr[Any], frame: Frame, name: str) -> Expr | Scalar: it = (self._expr.from_ir(e, frame, name).native for e in node.input) if fill is not None: - it = (pc.fill_null(native, lit(fill)) for native in it) + it = (pc.fill_null(native, fn.lit(fill)) for native in it) result = reduce(fn_native, it) if isinstance(result, pa.Scalar): return self._scalar.from_native(result, name, self.version) @@ -121,36 +104,36 @@ def func( return func def any_horizontal( - self, node: FunctionExpr[AnyHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[AnyHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.or_)(node, frame, name) def all_horizontal( - self, node: FunctionExpr[AllHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[AllHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.and_)(node, frame, name) def sum_horizontal( - self, node: FunctionExpr[F.SumHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.SumHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.add, fill=0)(node, frame, name) def min_horizontal( - self, node: FunctionExpr[F.MinHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MinHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.min_horizontal)(node, frame, name) def max_horizontal( - self, node: FunctionExpr[F.MaxHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MaxHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: return self._horizontal_function(fn.max_horizontal)(node, frame, name) def mean_horizontal( - self, node: FunctionExpr[F.MeanHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[F.MeanHorizontal], frame: Frame, name: str + ) -> Expr | Scalar: int64 = pa.int64() inputs = [self._expr.from_ir(e, frame, name).native for e in node.input] - filled = (pc.fill_null(native, lit(0)) for native in inputs) + filled = (pc.fill_null(native, fn.lit(0)) for native in inputs) # NOTE: `mypy` doesn't like that `add` is overloaded sum_not_null = reduce( fn.add, # type: ignore[arg-type] @@ -162,8 +145,8 @@ def mean_horizontal( return self._expr.from_native(result, name, self.version) def concat_str( - self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str - ) -> ArrowExpr | ArrowScalar: + self, node: FunctionExpr[ConcatStr], frame: 
Frame, name: str + ) -> Expr | Scalar: exprs = (self._expr.from_ir(e, frame, name) for e in node.input) aligned = (ser.native for ser in self._expr.align(exprs)) separator = node.function.separator @@ -173,9 +156,7 @@ def concat_str( return self._scalar.from_native(result, name, self.version) return self._expr.from_native(result, name, self.version) - def int_range( - self, node: RangeExpr[IntRange], frame: ArrowDataFrame, name: str - ) -> ArrowExpr: + def int_range(self, node: RangeExpr[IntRange], frame: Frame, name: str) -> Expr: start_: PythonLiteral end_: PythonLiteral start, end = node.function.unwrap_input(node) @@ -209,21 +190,12 @@ def int_range( raise InvalidOperationError(msg) @overload - def concat( - self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod - ) -> ArrowDataFrame: ... - + def concat(self, items: Iterable[Frame], *, how: ConcatMethod) -> Frame: ... @overload + def concat(self, items: Iterable[Series], *, how: Literal["vertical"]) -> Series: ... def concat( - self, items: Iterable[ArrowSeries], *, how: Literal["vertical"] - ) -> ArrowSeries: ... - - def concat( - self, - items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries], - *, - how: ConcatMethod, - ) -> ArrowDataFrame | ArrowSeries: + self, items: Iterable[Frame | Series], *, how: ConcatMethod + ) -> Frame | Series: if how == "vertical": return self._concat_vertical(items) if how == "horizontal": @@ -232,20 +204,16 @@ def concat( first = next(it) if self._is_series(first): raise TypeError(first) - dfs = cast("Sequence[ArrowDataFrame]", (first, *it)) + dfs = cast("Sequence[Frame]", (first, *it)) return self._concat_diagonal(dfs) - def _concat_diagonal(self, items: Iterable[ArrowDataFrame]) -> ArrowDataFrame: + def _concat_diagonal(self, items: Iterable[Frame]) -> Frame: return self._dataframe.from_native( fn.concat_vertical_table(df.native for df in items), self.version ) - def _concat_horizontal( - self, items: Iterable[ArrowDataFrame | ArrowSeries] - ) -> ArrowDataFrame: - def gen( - objs: Iterable[ArrowDataFrame | ArrowSeries], - ) -> Iterator[tuple[ChunkedArrayAny, str]]: + def _concat_horizontal(self, items: Iterable[Frame | Series]) -> Frame: + def gen(objs: Iterable[Frame | Series]) -> Iterator[tuple[ChunkedArrayAny, str]]: for item in objs: if self._is_series(item): yield item.native, item.name @@ -256,9 +224,7 @@ def gen( native = pa.Table.from_arrays(arrays, list(names)) return self._dataframe.from_native(native, self.version) - def _concat_vertical( - self, items: Iterable[ArrowDataFrame] | Iterable[ArrowSeries] - ) -> ArrowDataFrame | ArrowSeries: + def _concat_vertical(self, items: Iterable[Frame | Series]) -> Frame | Series: collected = collect(items) if is_tuple_of(collected, self._series): sers = collected diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 7a3902cb6b..23f7d27dd3 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -4,8 +4,8 @@ # - Any import typing as t -from narwhals._plan.common import Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: @@ -21,22 +21,22 @@ ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") -class BooleanFunction(Function): ... - - +# fmt: off +class BooleanFunction(Function, options=FunctionOptions.elementwise): ... class All(BooleanFunction, options=FunctionOptions.aggregation): ... 
- - -class AllHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... - - +class AllHorizontal(HorizontalFunction, BooleanFunction): ... class Any(BooleanFunction, options=FunctionOptions.aggregation): ... - - -class AnyHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... - - -class IsBetween(BooleanFunction, options=FunctionOptions.elementwise): +class AnyHorizontal(HorizontalFunction, BooleanFunction): ... +class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsFinite(BooleanFunction): ... +class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... +class IsNan(BooleanFunction): ... +class IsNull(BooleanFunction): ... +class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... +class Not(BooleanFunction, config=FEOptions.renamed("not_")): ... +# fmt: on +class IsBetween(BooleanFunction): """N-ary (expr, lower_bound, upper_bound).""" __slots__ = ("closed",) @@ -47,16 +47,7 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, Exp return expr, lower_bound, upper_bound -class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsFinite(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsIn(BooleanFunction, t.Generic[OtherT], options=FunctionOptions.elementwise): +class IsIn(BooleanFunction, t.Generic[OtherT]): __slots__ = ("other",) other: OtherT @@ -90,19 +81,3 @@ def __init__(self, *, other: ExprT) -> None: "You should provide an iterable instead." ) raise NotImplementedError(msg) - - -class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class IsNan(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsNull(BooleanFunction, options=FunctionOptions.elementwise): ... - - -class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... - - -class Not(BooleanFunction, options=FunctionOptions.elementwise): - """`__invert__`.""" diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 7fb58367f9..13791bed16 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -1,23 +1,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr +# fmt: off class CategoricalFunction(Function, accessor="cat"): ... - - -class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): ... - - +class GetCategories(CategoricalFunction): ... 
+# fmt: on class IRCatNamespace(IRNamespace): - def get_categories(self) -> GetCategories: - return GetCategories() + get_categories: ClassVar = GetCategories class ExprCatNamespace(ExprNamespace[IRCatNamespace]): diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 3e77d493d0..f73f21b26b 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -5,73 +5,40 @@ import sys from collections.abc import Iterable from decimal import Decimal +from operator import attrgetter from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload +from narwhals._plan._guards import is_function_expr, is_iterable_reject, is_literal +from narwhals._plan._immutable import Immutable +from narwhals._plan.options import ExprIROptions, FEOptions, FunctionOptions from narwhals._plan.typing import ( Accessor, DTypeT, ExprIRT, ExprIRT2, + FunctionT, IRNamespaceT, MapIR, NamedOrExprIRT, - NativeSeriesT, NonNestedDTypeT, + OneOrIterable, Seq, ) -from narwhals._utils import _hasattr_static from narwhals.dtypes import DType from narwhals.utils import Version if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable, Literal - - from typing_extensions import Never, Self, TypeIs, dataclass_transform - - from narwhals._plan import expr - from narwhals._plan.dummy import Expr, Selector, Series - from narwhals._plan.expr import ( - AggExpr, - Alias, - BinaryExpr, - Cast, - Column, - FunctionExpr, - WindowExpr, - ) + from typing import Any, Callable + + from typing_extensions import Self, TypeAlias + + from narwhals._plan.dummy import Expr, Selector + from narwhals._plan.expr import Alias, Cast, Column, FunctionExpr from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.options import FunctionOptions - from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral -else: - # NOTE: This isn't important to the proposal, just wanted IDE support - # for the **temporary** constructors. - # It is interesting how much boilerplate this avoids though 🤔 - # https://docs.python.org/3/library/typing.html#typing.dataclass_transform - def dataclass_transform( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - frozen_default: bool = False, - field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), - **kwargs: Any, - ) -> Callable[[T], T]: - def decorator(cls_or_fn: T) -> T: - cls_or_fn.__dataclass_transform__ = { - "eq_default": eq_default, - "order_default": order_default, - "kw_only_default": kw_only_default, - "frozen_default": frozen_default, - "field_specifiers": field_specifiers, - "kwargs": kwargs, - } - return cls_or_fn - - return decorator - if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 @@ -87,127 +54,109 @@ def replace(obj: T, /, **changes: Any) -> T: T = TypeVar("T") +Incomplete: TypeAlias = "Any" -_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase, camelCase string to snake_case. 
-@dataclass_transform(kw_only_default=True, frozen_default=True) -class Immutable: - __slots__ = (_IMMUTABLE_HASH_NAME,) - __immutable_hash_value__: int + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - @property - def __immutable_keys__(self) -> Iterator[str]: - slots: tuple[str, ...] = self.__slots__ - for name in slots: - if name != _IMMUTABLE_HASH_NAME: - yield name - @property - def __immutable_values__(self) -> Iterator[Any]: - for name in self.__immutable_keys__: - yield getattr(self, name) +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - @property - def __immutable_items__(self) -> Iterator[tuple[str, Any]]: - for name in self.__immutable_keys__: - yield name, getattr(self, name) - @property - def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *self.__immutable_values__)) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ - - def __setattr__(self, name: str, value: Never) -> Never: - msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." - raise AttributeError(msg) - - def __replace__(self, **changes: Any) -> Self: - """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 - if len(changes) == 1: - k_new, v_new = next(iter(changes.items())) - # NOTE: Will trigger an attribute error if invalid name - if getattr(self, k_new) == v_new: - return self - changed = dict(self.__immutable_items__) - # Now we *don't* need to check the key is valid - changed[k_new] = v_new - else: - changed = dict(self.__immutable_items__) - changed |= changes - return type(self)(**changed) +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" - def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: - super().__init_subclass__(*args, **kwds) - if cls.__slots__: - ... - else: - cls.__slots__ = () - - def __hash__(self) -> int: - return self.__immutable_hash__ - - def __eq__(self, other: object) -> bool: - if self is other: - return True - if type(self) is not type(other): - return False - return all( - getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ - ) - - def __str__(self) -> str: - # NOTE: Debug repr, closer to constructor - fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) - return f"{type(self).__name__}({fields})" - - def __init__(self, **kwds: Any) -> None: - # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! - # Just need a quick way to demonstrate `ExprIR` and interactions - required: set[str] = set(self.__immutable_keys__) - if not required and not kwds: - # NOTE: Fastpath for empty slots - ... 
- elif required == set(kwds): - # NOTE: Everything is as expected - for name, value in kwds.items(): - object.__setattr__(self, name, value) - elif missing := required.difference(kwds): - msg = ( - f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" - f"but missing values for {sorted(missing)!r}" - ) - raise TypeError(msg) - else: - extra = set(kwds).difference(required) + +def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or _pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def _dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: + getter = attrgetter(_dispatch_method_name(tp)) + if tp.__expr_ir_config__.origin == "expr": + return getter + return lambda ctx: getter(ctx.__narwhals_namespace__()) + + +def _dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + if not tp.__expr_ir_config__.allow_dispatch: + + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: msg = ( - f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" - f"but got unknown arguments {sorted(extra)!r}" + f"{tp.__name__!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" ) raise TypeError(msg) + return _ + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + -def _field_str(name: str, value: Any) -> str: - if isinstance(value, tuple): - inner = ", ".join(f"{v}" for v in value) - return f"{name}=[{inner}]" - if isinstance(value, str): - return f"{name}={value!r}" - return f"{name}={value}" +def _dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" + _child: ClassVar[Seq[str]] = () + """Nested node names, in iteration order.""" + + __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] + ] + + def __init_subclass__( + cls: type[Self], + *args: Any, + child: Seq[str] = (), + config: ExprIROptions | None = None, + **kwds: Any, + ) -> None: + super().__init_subclass__(*args, **kwds) + if child: + cls._child = child + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + ) -> R_co: + """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import dummy - if version is Version.MAIN: - return dummy.Expr._from_ir(self) - return dummy.ExprV1._from_ir(self) + tp = dummy.Expr if version is Version.MAIN else dummy.ExprV1 + return tp._from_ir(self) @property def is_scalar(self) -> bool: @@ -221,8 +170,11 @@ def map_ir(self, function: MapIR, /) -> ExprIR: 
[`polars_plan::plans::iterator::Expr.map_expr`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/iterator.rs#L152-L159 [`polars_plan::plans::visitor::visitors`]: https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/visitor/visitors.rs """ - msg = f"Need to handle recursive visiting first for {type(self).__qualname__!r}!\n\n{self!r}" - raise NotImplementedError(msg) + if not self._child: + return function(self) + children = ((name, getattr(self, name)) for name in self._child) + changed = {name: _map_ir_child(child, function) for name, child in children} + return function(replace(self, **changed)) def iter_left(self) -> Iterator[ExprIR]: """Yield nodes root->leaf. @@ -247,6 +199,13 @@ def iter_left(self) -> Iterator[ExprIR]: >>> list(d._ir.iter_left()) [col('a'), col('a').alias('b'), col('a').alias('b').min(), col('a').alias('b').min().alias('c'), col('e'), col('f'), col('a').alias('b').min().alias('c').over([col('e'), col('f')])] """ + for name in self._child: + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_left() + else: + for node in child: + yield from node.iter_left() yield self def iter_right(self) -> Iterator[ExprIR]: @@ -276,6 +235,13 @@ def iter_right(self) -> Iterator[ExprIR]: [col('a').alias('b').min().alias('c').over([col('e'), col('f')]), col('f'), col('e'), col('a').alias('b').min().alias('c'), col('a').alias('b').min(), col('a').alias('b'), col('a')] """ yield self + for name in reversed(self._child): + child: ExprIR | Seq[ExprIR] = getattr(self, name) + if isinstance(child, ExprIR): + yield from child.iter_right() + else: + for node in reversed(child): + yield from node.iter_right() def iter_root_names(self) -> Iterator[ExprIR]: """Override for different iteration behavior in `ExprIR.meta.root_names`. @@ -313,7 +279,7 @@ def _repr_html_(self) -> str: return self.__repr__() -class SelectorIR(ExprIR): +class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): def to_narwhals(self, version: Version = Version.MAIN) -> Selector: from narwhals._plan import dummy @@ -366,12 +332,9 @@ def from_ir(expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: """ return NamedIR(expr=expr, name=expr.meta.output_name(raise_if_undetermined=True)) - def map_ir(self, function: MapIR, /) -> NamedIR[ExprIR]: + def map_ir(self, function: MapIR, /) -> Self: """**WARNING**: don't use renaming ops here, or `self.name` is invalid.""" - return self.with_expr(function(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIRT2, /) -> NamedIR[ExprIRT2]: - return cast("NamedIR[ExprIRT2]", replace(self, expr=expr)) + return replace(self, expr=function(self.expr.map_ir(function))) def __repr__(self) -> str: return f"{self.name}={self.expr!r}" @@ -397,7 +360,7 @@ def is_elementwise_top_level(self) -> bool: return ir.options.is_elementwise() if is_literal(ir): return ir.is_scalar - return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.Ternary, expr.Cast)) + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) class IRNamespace(Immutable): @@ -428,26 +391,19 @@ def _with_unary(self, function: Function, /) -> Expr: return self._expr._with_unary(function) -def _function_options_default() -> FunctionOptions: - from narwhals._plan.options import FunctionOptions - - return FunctionOptions.default() - - class Function(Immutable): """Shared by expr functions and namespace functions. 
- Only valid in `FunctionExpr.function` - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 """ - _accessor: ClassVar[Accessor | None] = None - """Namespace accessor name, if any.""" - _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( - _function_options_default + FunctionOptions.default ) + __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] + ] @property def function_options(self) -> FunctionOptions: @@ -463,136 +419,29 @@ def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: return FunctionExpr(input=inputs, function=self, options=self.function_options) def __init_subclass__( - cls, + cls: type[Self], *args: Any, accessor: Accessor | None = None, options: Callable[[], FunctionOptions] | None = None, + config: FEOptions | None = None, **kwds: Any, ) -> None: super().__init_subclass__(*args, **kwds) if accessor: - cls._accessor = accessor + config = replace(config or FEOptions.default(), accessor_name=accessor) if options: cls._function_options = staticmethod(options) + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) def __repr__(self) -> str: - return _function_repr(type(self)) - - -# TODO @dangotbanned: Add caching strategy? -def _function_repr(tp: type[Function], /) -> str: - name = _pascal_to_snake_case(tp.__name__) - return f"{ns_name}.{name}" if (ns_name := tp._accessor) else name - - -def _pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. - - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" - - -_NON_NESTED_LITERAL_TPS = ( - int, - float, - str, - dt.date, - dt.time, - dt.timedelta, - bytes, - Decimal, -) - - -def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]: - return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS) - - -def is_expr(obj: Any) -> TypeIs[Expr]: - from narwhals._plan.dummy import Expr - - return isinstance(obj, Expr) - - -def is_column(obj: Any) -> TypeIs[Expr]: - """Indicate if the given object is a basic/unaliased column. - - https://github.com/pola-rs/polars/blob/a3d6a3a7863b4d42e720a05df69ff6b6f5fc551f/py-polars/polars/_utils/various.py#L164-L168. 
- """ - return is_expr(obj) and obj.meta.is_column() + return _dispatch_method_name(type(self)) -def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]: - from narwhals._plan.dummy import Series - - return isinstance(obj, Series) - - -def is_compliant_series( - obj: CompliantSeries[NativeSeriesT] | Any, -) -> TypeIs[CompliantSeries[NativeSeriesT]]: - return _hasattr_static(obj, "__narwhals_series__") - - -def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]: - from narwhals._plan.dummy import Series - - return isinstance(obj, (str, bytes, Series)) or is_compliant_series(obj) - - -def is_window_expr(obj: Any) -> TypeIs[WindowExpr]: - from narwhals._plan.expr import WindowExpr - - return isinstance(obj, WindowExpr) - - -def is_function_expr(obj: Any) -> TypeIs[FunctionExpr[Any]]: - from narwhals._plan.expr import FunctionExpr - - return isinstance(obj, FunctionExpr) - - -def is_binary_expr(obj: Any) -> TypeIs[BinaryExpr]: - from narwhals._plan.expr import BinaryExpr - - return isinstance(obj, BinaryExpr) - - -def is_agg_expr(obj: Any) -> TypeIs[AggExpr]: - from narwhals._plan.expr import AggExpr - - return isinstance(obj, AggExpr) - - -def is_aggregation(obj: Any) -> TypeIs[AggExpr | FunctionExpr[Any]]: - """Superset of `ExprIR.is_scalar`, excludes literals & len.""" - return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar) - - -def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]: - from narwhals._plan import expr - - return isinstance(obj, expr.Literal) - - -def is_horizontal_reduction(obj: FunctionExpr[Any] | Any) -> TypeIs[FunctionExpr[Any]]: - return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() - - -def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: - return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) +class HorizontalFunction( + Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() +): ... def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: @@ -641,10 +490,14 @@ def map_ir( return origin.map_ir(function) +def _map_ir_child(obj: ExprIR | Seq[ExprIR], fn: MapIR, /) -> ExprIR | Seq[ExprIR]: + return obj.map_ir(fn) if isinstance(obj, ExprIR) else tuple(e.map_ir(fn) for e in obj) + + # TODO @dangotbanned: Review again and try to work around (https://github.com/microsoft/pyright/issues/10673#issuecomment-3033789021) # The issue is `T` possibly being `Iterable` # Ignoring here still leaks the issue to the caller, where you need to annotate the base case -def flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]: +def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: """Fully unwrap all levels of nesting. Aiming to reduce the chances of passing an unhashable argument. 
diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 85b8ac2ad7..ab89b97c96 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -3,17 +3,12 @@ import builtins import typing as t -from narwhals._plan import boolean, expr, expr_parsing as parse, functions as F -from narwhals._plan.common import ( - into_dtype, - is_non_nested_literal, - is_series, - py_to_narwhals_dtype, -) +from narwhals._plan import _guards, boolean, expr, expr_parsing as parse, functions as F +from narwhals._plan.common import into_dtype, py_to_narwhals_dtype from narwhals._plan.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange -from narwhals._plan.strings import ConcatHorizontal +from narwhals._plan.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import Version, flatten @@ -39,9 +34,9 @@ def nth(*indices: int | t.Sequence[int]) -> Expr: def lit( value: NonNestedLiteral | Series[NativeSeriesT], dtype: IntoDType | None = None ) -> Expr: - if is_series(value): + if _guards.is_series(value): return SeriesLiteral(value=value).to_literal().to_narwhals() - if not is_non_nested_literal(value): + if not _guards.is_non_nested_literal(value): msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}." raise TypeError(msg) if dtype is None: @@ -121,7 +116,7 @@ def concat_str( ) -> Expr: it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( - ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) + ConcatStr(separator=separator, ignore_nulls=ignore_nulls) .to_function_expr(*it) .to_narwhals() ) diff --git a/narwhals/_plan/dummy.py b/narwhals/_plan/dummy.py index 6ef174ec06..0a1e469917 100644 --- a/narwhals/_plan/dummy.py +++ b/narwhals/_plan/dummy.py @@ -3,8 +3,8 @@ from __future__ import annotations import math -import typing as t -from typing import TYPE_CHECKING, Generic +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload from narwhals._plan import ( aggregation as agg, @@ -15,7 +15,8 @@ functions as F, operators as ops, ) -from narwhals._plan.common import NamedIR, into_dtype, is_column, is_expr, is_series +from narwhals._plan._guards import is_column, is_expr, is_series +from narwhals._plan.common import into_dtype from narwhals._plan.contexts import ExprContext from narwhals._plan.options import ( EWMOptions, @@ -33,13 +34,11 @@ from narwhals.schema import Schema if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - import pyarrow as pa from typing_extensions import Never, Self from narwhals._plan.categorical import ExprCatNamespace - from narwhals._plan.common import ExprIR, Function + from narwhals._plan.common import ExprIR, Function, NamedIR from narwhals._plan.lists import ExprListNamespace from narwhals._plan.meta import IRMetaNamespace from narwhals._plan.name import ExprNameNamespace @@ -52,7 +51,7 @@ from narwhals._plan.strings import ExprStringNamespace from narwhals._plan.struct import ExprStructNamespace from narwhals._plan.temporal import ExprDateTimeNamespace - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq, Udf + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq, Udf from narwhals.dtypes import DType from narwhals.typing import ( ClosedInterval, @@ -69,10 +68,10 @@ # NOTE: Trying to keep consistent logic between `DataFrame.sort` and `Expr.sort_by` def _parse_sort_by( - by: 
IntoExpr | Iterable[IntoExpr] = (), + by: OneOrIterable[IntoExpr] = (), *more_by: IntoExpr, - descending: bool | t.Iterable[bool] = False, - nulls_last: bool | t.Iterable[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> tuple[Seq[ExprIR], SortMultipleOptions]: sort_by = parse.parse_into_seq_of_expr_ir(by, *more_by) if length_changing := next((e for e in sort_by if e.is_scalar), None): @@ -86,7 +85,7 @@ def _parse_sort_by( # Entirely ignoring namespace + function binding class Expr: _ir: ExprIR - _version: t.ClassVar[Version] = Version.MAIN + _version: ClassVar[Version] = Version.MAIN def __repr__(self) -> str: return f"nw._plan.Expr({self.version.name.lower()}):\n{self._ir!r}" @@ -114,7 +113,7 @@ def alias(self, name: str) -> Self: def cast(self, dtype: IntoDType) -> Self: return self._from_ir(self._ir.cast(into_dtype(dtype))) - def exclude(self, *names: str | t.Iterable[str]) -> Self: + def exclude(self, *names: OneOrIterable[str]) -> Self: return self._from_ir(expr.Exclude.from_names(self._ir, *names)) def count(self) -> Self: @@ -165,8 +164,8 @@ def quantile( def over( self, - *partition_by: IntoExpr | t.Iterable[IntoExpr], - order_by: IntoExpr | t.Iterable[IntoExpr] = None, + *partition_by: OneOrIterable[IntoExpr], + order_by: OneOrIterable[IntoExpr] = None, descending: bool = False, nulls_last: bool = False, ) -> Self: @@ -191,10 +190,10 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: def sort_by( self, - by: IntoExpr | t.Iterable[IntoExpr], + by: OneOrIterable[IntoExpr], *more_by: IntoExpr, - descending: bool | t.Iterable[bool] = False, - nulls_last: bool | t.Iterable[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> Self: keys, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last @@ -202,9 +201,7 @@ def sort_by( return self._from_ir(expr.SortBy(expr=self._ir, by=keys, options=opts)) def filter( - self, - *predicates: IntoExprColumn | t.Iterable[IntoExprColumn], - **constraints: t.Any, + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> Self: by = parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return self._from_ir(expr.Filter(expr=self._ir, by=by)) @@ -217,7 +214,7 @@ def abs(self) -> Self: def hist( self, - bins: t.Sequence[float] | None = None, + bins: Sequence[float] | None = None, *, bin_count: int | None = None, include_breakpoint: bool = True, @@ -371,20 +368,20 @@ def ewm_mean( def replace_strict( self, - old: t.Sequence[t.Any] | t.Mapping[t.Any, t.Any], - new: t.Sequence[t.Any] | None = None, + old: Sequence[Any] | Mapping[Any, Any], + new: Sequence[Any] | None = None, *, return_dtype: IntoDType | None = None, ) -> Self: - before: Seq[t.Any] - after: Seq[t.Any] + before: Seq[Any] + after: Seq[Any] if new is None: - if not isinstance(old, t.Mapping): + if not isinstance(old, Mapping): msg = "`new` argument is required if `old` argument is not a Mapping type" raise TypeError(msg) before = tuple(old) after = tuple(old.values()) - elif isinstance(old, t.Mapping): + elif isinstance(old, Mapping): msg = "`new` argument cannot be used if `old` argument is a Mapping type" raise TypeError(msg) else: @@ -455,10 +452,10 @@ def is_between( boolean.IsBetween(closed=closed).to_function_expr(self._ir, *it) ) - def is_in(self, other: t.Iterable[t.Any]) -> Self: + def is_in(self, other: Iterable[Any]) -> Self: if is_series(other): return 
self._with_unary(boolean.IsInSeries.from_series(other)) - if isinstance(other, t.Iterable): + if isinstance(other, Iterable): return self._with_unary(boolean.IsInSeq.from_iterable(other)) if is_expr(other): return self._with_unary(boolean.IsInExpr(other=other._ir)) @@ -627,9 +624,9 @@ def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override] def _to_expr(self) -> Expr: return self._ir.to_narwhals(self.version) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __or__(self, other: Self) -> Self: ... - @t.overload + @overload def __or__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): @@ -637,9 +634,9 @@ def __or__(self, other: IntoExprColumn | int | bool) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() | other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __and__(self, other: Self) -> Self: ... - @t.overload + @overload def __and__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): @@ -649,9 +646,9 @@ def __and__(self, other: IntoExprColumn | int | bool) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() & other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __sub__(self, other: Self) -> Self: ... - @t.overload + @overload def __sub__(self, other: IntoExpr) -> Expr: ... def __sub__(self, other: IntoExpr) -> Self | Expr: if isinstance(other, type(self)): @@ -659,9 +656,9 @@ def __sub__(self, other: IntoExpr) -> Self | Expr: return self._from_ir(op.to_binary_selector(self._ir, other._ir)) return self._to_expr() - other - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __xor__(self, other: Self) -> Self: ... - @t.overload + @overload def __xor__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if isinstance(other, type(self)): @@ -672,41 +669,41 @@ def __xor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: def __invert__(self) -> Self: return self._from_ir(expr.InvertSelector(selector=self._ir)) - def __add__(self, other: t.Any) -> Expr: # type: ignore[override] + def __add__(self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, type(self)): msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" raise TypeError(msg) return self._to_expr() + other # type: ignore[no-any-return] - def __radd__(self, other: t.Any) -> Never: + def __radd__(self, other: Any) -> Never: msg = "unsupported operand type(s) for op: ('Expr' + 'Selector')" raise TypeError(msg) - def __rsub__(self, other: t.Any) -> Never: + def __rsub__(self, other: Any) -> Never: msg = "unsupported operand type(s) for op: ('Expr' - 'Selector')" raise TypeError(msg) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __rand__(self, other: Self) -> Self: ... - @t.overload + @overload def __rand__(self, other: IntoExprColumn | int | bool) -> Expr: ... 
def __rand__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) & self return self._to_expr().__rand__(other) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __ror__(self, other: Self) -> Self: ... - @t.overload + @overload def __ror__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __ror__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): return by_name(name) | self return self._to_expr().__ror__(other) - @t.overload # type: ignore[override] + @overload # type: ignore[override] def __rxor__(self, other: Self) -> Self: ... - @t.overload + @overload def __rxor__(self, other: IntoExprColumn | int | bool) -> Expr: ... def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: if is_column(other) and (name := other.meta.output_name()): @@ -715,16 +712,16 @@ def __rxor__(self, other: IntoExprColumn | int | bool) -> Self | Expr: class ExprV1(Expr): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 class SelectorV1(Selector): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 class BaseFrame(Generic[NativeFrameT]): - _compliant: CompliantBaseFrame[t.Any, NativeFrameT] - _version: t.ClassVar[Version] = Version.MAIN + _compliant: CompliantBaseFrame[Any, NativeFrameT] + _version: ClassVar[Version] = Version.MAIN @property def version(self) -> Version: @@ -742,13 +739,11 @@ def __repr__(self) -> str: # pragma: no cover return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) @classmethod - def from_native(cls, native: t.Any, /) -> Self: + def from_native(cls, native: Any, /) -> Self: raise NotImplementedError @classmethod - def _from_compliant( - cls, compliant: CompliantBaseFrame[t.Any, NativeFrameT], / - ) -> Self: + def _from_compliant(cls, compliant: CompliantBaseFrame[Any, NativeFrameT], /) -> Self: obj = cls.__new__(cls) obj._compliant = compliant return obj @@ -758,8 +753,8 @@ def to_native(self) -> NativeFrameT: def _project( self, - exprs: tuple[IntoExpr | Iterable[IntoExpr], ...], - named_exprs: dict[str, t.Any], + exprs: tuple[OneOrIterable[IntoExpr], ...], + named_exprs: dict[str, Any], context: ExprContext, /, ) -> tuple[Seq[NamedIR[ExprIR]], FrozenSchema]: @@ -770,15 +765,13 @@ def _project( named_irs = expr_expansion.into_named_irs(irs, output_names) return schema_frozen.project(named_irs, context) - def select(self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any) -> Self: + def select(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.SELECT ) return self._from_compliant(self._compliant.select(named_irs)) - def with_columns( - self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: t.Any - ) -> Self: + def with_columns(self, *exprs: OneOrIterable[IntoExpr], **named_exprs: Any) -> Self: named_irs, schema_projected = self._project( exprs, named_exprs, ExprContext.WITH_COLUMNS ) @@ -786,10 +779,10 @@ def with_columns( def sort( self, - by: str | Iterable[str], + by: OneOrIterable[str], *more_by: str, - descending: bool | Sequence[bool] = False, - nulls_last: bool | Sequence[bool] = False, + descending: OneOrIterable[bool] = False, + nulls_last: OneOrIterable[bool] = False, ) -> Self: sort, opts = _parse_sort_by( by, *more_by, descending=descending, nulls_last=nulls_last @@ 
-802,7 +795,7 @@ def sort( class DataFrame(BaseFrame[NativeDataFrameT], Generic[NativeDataFrameT, NativeSeriesT]): - _compliant: CompliantDataFrame[t.Any, NativeDataFrameT, NativeSeriesT] + _compliant: CompliantDataFrame[Any, NativeDataFrameT, NativeSeriesT] @property def _series(self) -> type[Series[NativeSeriesT]]: @@ -812,7 +805,7 @@ def _series(self) -> type[Series[NativeSeriesT]]: @classmethod def from_native( # type: ignore[override] cls, native: NativeFrame, / - ) -> DataFrame[pa.Table, pa.ChunkedArray[t.Any]]: + ) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: if is_pyarrow_table(native): from narwhals._plan.arrow.dataframe import ArrowDataFrame @@ -820,22 +813,19 @@ def from_native( # type: ignore[override] raise NotImplementedError(type(native)) - @t.overload + @overload def to_dict( - self, *, as_series: t.Literal[True] = ... + self, *, as_series: Literal[True] = ... ) -> dict[str, Series[NativeSeriesT]]: ... - - @t.overload - def to_dict(self, *, as_series: t.Literal[False]) -> dict[str, list[t.Any]]: ... - - @t.overload + @overload + def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... + @overload def to_dict( self, *, as_series: bool - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: ... - + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: ... def to_dict( self, *, as_series: bool = True - ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[t.Any]]: + ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: if as_series: return { key: self._series._from_compliant(value) @@ -849,7 +839,7 @@ def __len__(self) -> int: class Series(Generic[NativeSeriesT]): _compliant: CompliantSeries[NativeSeriesT] - _version: t.ClassVar[Version] = Version.MAIN + _version: ClassVar[Version] = Version.MAIN @property def version(self) -> Version: @@ -867,7 +857,7 @@ def name(self) -> str: @classmethod def from_native( cls, native: NativeSeries, name: str = "", / - ) -> Series[pa.ChunkedArray[t.Any]]: + ) -> Series[pa.ChunkedArray[Any]]: if is_pyarrow_chunked_array(native): from narwhals._plan.arrow.series import ArrowSeries @@ -886,12 +876,12 @@ def _from_compliant(cls, compliant: CompliantSeries[NativeSeriesT], /) -> Self: def to_native(self) -> NativeSeriesT: return self._compliant.native - def to_list(self) -> list[t.Any]: + def to_list(self) -> list[Any]: return self._compliant.to_list() - def __iter__(self) -> t.Iterator[t.Any]: + def __iter__(self) -> Iterator[Any]: yield from self.to_native() class SeriesV1(Series[NativeSeriesT]): - _version: t.ClassVar[Version] = Version.V1 + _version: ClassVar[Version] = Version.V1 diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 610a9e80a1..7d35237f16 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -6,23 +6,20 @@ # - Literal import typing as t -from narwhals._plan import common from narwhals._plan.aggregation import AggExpr, OrderableAggExpr from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias +from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT, LeftSelectorT, LeftT, - LeftT2, LiteralT, - MapIR, OperatorT, RangeT, RightSelectorT, RightT, - RightT2, RollingT, SelectorOperatorT, SelectorT, @@ -37,6 +34,7 @@ from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, 
SortMultipleOptions, SortOptions + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -66,7 +64,7 @@ "SelectorIR", "Sort", "SortBy", - "Ternary", + "TernaryExpr", "WindowExpr", "col", ] @@ -88,7 +86,7 @@ def index_columns(*indices: int) -> IndexColumns: return IndexColumns(indices=indices) -class Alias(ExprIR): +class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "name") expr: ExprIR name: str @@ -100,41 +98,18 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.alias({self.name!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Column(ExprIR): +class Column(ExprIR, config=ExprIROptions.namespaced("col")): __slots__ = ("name",) name: str def __repr__(self) -> str: return f"col({self.name!r})" - def with_name(self, name: str, /) -> Column: - return common.replace(self, name=name) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - - -class _ColumnSelection(ExprIR): +class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): """Nodes which can resolve to `Column`(s) with a `Schema`.""" - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class Columns(_ColumnSelection): __slots__ = ("names",) @@ -165,7 +140,7 @@ def __repr__(self) -> str: return "all()" -class Exclude(_ColumnSelection): +class Exclude(_ColumnSelection, child=("expr",)): __slots__ = ("expr", "names") expr: ExprIR """Default is `all()`.""" @@ -180,22 +155,8 @@ def from_names(expr: ExprIR, *names: str | t.Iterable[str]) -> Exclude: def __repr__(self) -> str: return f"{self.expr!r}.exclude({list(self.names)!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Literal(ExprIR, t.Generic[LiteralT]): +class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" __slots__ = ("value",) @@ -219,9 +180,6 @@ def __repr__(self) -> str: def unwrap(self) -> LiteralT: return self.value.unwrap() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class _BinaryOp(ExprIR, t.Generic[LeftT, OperatorT, RightT]): __slots__ = ("left", "op", "right") @@ -238,40 +196,17 @@ def __repr__(self) -> str: class BinaryExpr( - _BinaryOp[LeftT, OperatorT, RightT], t.Generic[LeftT, OperatorT, RightT] + _BinaryOp[LeftT, OperatorT, RightT], + t.Generic[LeftT, OperatorT, RightT], + child=("left", "right"), ): """Application of two exprs via an `Operator`.""" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.left.iter_left() - yield from self.right.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - 
yield self - yield from self.right.iter_right() - yield from self.left.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.left.iter_output_name() - def with_left(self, left: LeftT2, /) -> BinaryExpr[LeftT2, OperatorT, RightT]: - changed = common.replace(self, left=left) - return t.cast("BinaryExpr[LeftT2, OperatorT, RightT]", changed) - - def with_right(self, right: RightT2, /) -> BinaryExpr[LeftT, OperatorT, RightT2]: - changed = common.replace(self, right=right) - return t.cast("BinaryExpr[LeftT, OperatorT, RightT2]", changed) - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function( - self.with_left(self.left.map_ir(function)).with_right( - self.right.map_ir(function) - ) - ) - -class Cast(ExprIR): +class Cast(ExprIR, child=("expr",)): __slots__ = ("expr", "dtype") # noqa: RUF023 expr: ExprIR dtype: DType @@ -283,25 +218,11 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.cast({self.dtype!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class Sort(ExprIR): +class Sort(ExprIR, child=("expr",)): __slots__ = ("expr", "options") expr: ExprIR options: SortOptions @@ -314,25 +235,11 @@ def __repr__(self) -> str: direction = "desc" if self.options.descending else "asc" return f"{self.expr!r}.sort({direction})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class SortBy(ExprIR): +class SortBy(ExprIR, child=("expr", "by")): """https://github.com/narwhals-dev/narwhals/issues/2534.""" __slots__ = ("expr", "by", "options") # noqa: RUF023 @@ -347,33 +254,11 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.sort_by(by={self.by!r}, options={self.options!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - by = (ir.map_ir(function) for ir in self.by) - return function(self.with_expr(self.expr.map_ir(function)).with_by(by)) - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - - def with_by(self, by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, by=collect(by)) - - -class FunctionExpr(ExprIR, t.Generic[FunctionT]): +class FunctionExpr(ExprIR, t.Generic[FunctionT], child=("input",)): """**Representing `Expr::Function`**. 
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 @@ -392,15 +277,6 @@ class FunctionExpr(ExprIR, t.Generic[FunctionT]): def is_scalar(self) -> bool: return self.function.is_scalar - def with_options(self, options: FunctionOptions, /) -> Self: - return common.replace(self, options=self.options.with_flags(options.flags)) - - def with_input(self, input: t.Iterable[ExprIR], /) -> Self: # noqa: A002 - return common.replace(self, input=collect(input)) - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_input(ir.map_ir(function) for ir in self.input)) - def __repr__(self) -> str: if self.input: first = self.input[0] @@ -409,16 +285,6 @@ def __repr__(self) -> str: return f"{first!r}.{self.function!r}()" return f"{self.function!r}()" - def iter_left(self) -> t.Iterator[ExprIR]: - for e in self.input: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.input): - yield from e.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: """When we have multiple inputs, we want the name of the left-most expression. @@ -447,13 +313,25 @@ def __init__( raise function_expr_invalid_operation_error(function, parent) super().__init__(**dict(input=input, function=function, options=options, **kwds)) + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RollingExpr(FunctionExpr[RollingT]): ... -class AnonymousExpr(FunctionExpr["MapBatches"]): +class AnonymousExpr( + FunctionExpr["MapBatches"], config=ExprIROptions.renamed("map_batches") +): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RangeExpr(FunctionExpr[RangeT]): """E.g. `int_range(...)`. @@ -484,7 +362,7 @@ def __repr__(self) -> str: return f"{self.function!r}({list(self.input)!r})" -class Filter(ExprIR): +class Filter(ExprIR, child=("expr", "by")): __slots__ = ("expr", "by") # noqa: RUF023 expr: ExprIR by: ExprIR @@ -496,26 +374,13 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.filter({self.by!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - yield from self.by.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.by.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - expr, by = self.expr, self.by - changed = common.replace(self, expr=expr.map_ir(function), by=by.map_ir(function)) - return function(changed) - -class WindowExpr(ExprIR): +class WindowExpr( + ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") +): """A fully specified `.over()`, that occurred after another expression. 
Related: @@ -533,35 +398,15 @@ class WindowExpr(ExprIR): def __repr__(self) -> str: return f"{self.expr!r}.over({list(self.partition_by)!r})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.partition_by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.expr.iter_output_name() - def map_ir(self, function: MapIR, /) -> ExprIR: - over = self.with_expr(self.expr.map_ir(function)).with_partition_by( - ir.map_ir(function) for ir in self.partition_by - ) - return function(over) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - - def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, partition_by=collect(partition_by)) - -class OrderedWindowExpr(WindowExpr): +class OrderedWindowExpr( + WindowExpr, + child=("expr", "partition_by", "order_by"), + config=ExprIROptions.renamed("over_ordered"), +): __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 expr: ExprIR partition_by: Seq[ExprIR] @@ -577,41 +422,18 @@ def __repr__(self) -> str: args = f"partition_by={list(self.partition_by)!r}, order_by={list(order)!r}" return f"{self.expr!r}.over({args})" - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.expr.iter_left() - for e in self.partition_by: - yield from e.iter_left() - for e in self.order_by: - yield from e.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - for e in reversed(self.order_by): - yield from e.iter_right() - for e in reversed(self.partition_by): - yield from e.iter_right() - yield from self.expr.iter_right() - def iter_root_names(self) -> t.Iterator[ExprIR]: # NOTE: `order_by` is never considered in `polars` # To match that behavior for `root_names` - but still expand in all other cases # - this little escape hatch exists # https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/plans/iterator.rs#L76-L86 - yield from super().iter_left() - - def map_ir(self, function: MapIR, /) -> ExprIR: - over = self.with_expr(self.expr.map_ir(function)).with_partition_by( - ir.map_ir(function) for ir in self.partition_by - ) - over = over.with_order_by(ir.map_ir(function) for ir in self.order_by) - return function(over) - - def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: - return common.replace(self, order_by=collect(order_by)) + yield from self.expr.iter_left() + for e in self.partition_by: + yield from e.iter_left() + yield self -class Len(ExprIR): +class Len(ExprIR, config=ExprIROptions.namespaced()): @property def is_scalar(self) -> bool: return True @@ -623,9 +445,6 @@ def name(self) -> str: def __repr__(self) -> str: return "len()" - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class RootSelector(SelectorIR): """A single selector expression.""" @@ -639,9 +458,6 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return self.selector.matches_column(name, dtype) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class BinarySelector( _BinaryOp[LeftSelectorT, SelectorOperatorT, RightSelectorT], @@ -655,9 +471,6 @@ def matches_column(self, name: str, dtype: DType) -> bool: right = 
self.right.matches_column(name, dtype) return bool(self.op(left, right)) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - class InvertSelector(SelectorIR, t.Generic[SelectorT]): __slots__ = ("selector",) @@ -669,14 +482,11 @@ def __repr__(self) -> str: def matches_column(self, name: str, dtype: DType) -> bool: return not self.selector.matches_column(name, dtype) - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self) - -class Ternary(ExprIR): +class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): """When-Then-Otherwise.""" - __slots__ = ("predicate", "truthy", "falsy") # noqa: RUF023 + __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 predicate: ExprIR truthy: ExprIR falsy: ExprIR @@ -690,24 +500,5 @@ def __repr__(self) -> str: f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})" ) - def iter_left(self) -> t.Iterator[ExprIR]: - yield from self.truthy.iter_left() - yield from self.falsy.iter_left() - yield from self.predicate.iter_left() - yield self - - def iter_right(self) -> t.Iterator[ExprIR]: - yield self - yield from self.predicate.iter_right() - yield from self.falsy.iter_right() - yield from self.truthy.iter_right() - def iter_output_name(self) -> t.Iterator[ExprIR]: yield from self.truthy.iter_output_name() - - def map_ir(self, function: MapIR, /) -> ExprIR: - predicate = self.predicate.map_ir(function) - truthy = self.truthy.map_ir(function) - falsy = self.falsy.map_ir(function) - changed = common.replace(self, predicate=predicate, truthy=truthy, falsy=falsy) - return function(changed) diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d7ab345f81..2fdf92ef7d 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -40,17 +40,12 @@ from collections import deque from functools import lru_cache -from itertools import chain from typing import TYPE_CHECKING -from narwhals._plan import common -from narwhals._plan.common import ( - ExprIR, - Immutable, - NamedIR, - SelectorIR, - is_horizontal_reduction, -) +from narwhals._plan import common, meta +from narwhals._plan._guards import is_horizontal_reduction +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import ExprIR, NamedIR, SelectorIR from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, @@ -79,11 +74,10 @@ from narwhals.exceptions import ComputeError, InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterable, Iterator, Sequence from typing_extensions import TypeAlias - from narwhals._plan.dummy import Expr from narwhals._plan.typing import Seq from narwhals.dtypes import DType @@ -153,10 +147,6 @@ def from_ir(ir: ExprIR, /) -> ExpansionFlags: has_exclude=has_exclude, ) - @classmethod - def from_expr(cls, expr: Expr, /) -> ExpansionFlags: - return cls.from_ir(expr._ir) - def with_multiple_columns(self) -> ExpansionFlags: return common.replace(self, multiple_columns=True) @@ -194,7 +184,7 @@ def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: """Raise an appropriate error if we can't materialize.""" output_names = _ensure_output_names_unique(exprs) - root_names = _root_names_unique(exprs) + root_names = meta.root_names_unique(exprs) if not (set(schema.names).issuperset(root_names)): raise column_not_found_error(root_names, schema) return output_names 
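As a rough sketch of how the `ExprIROptions.namespaced(...)`/`renamed(...)` configs in `expr.py` above interact with the attrgetter-based dispatch added in `common.py`: a node's class name is converted to snake_case unless an override is supplied, and the result is looked up on a backend context. `FakeBackend` and the exact resolution rules here are assumptions for illustration, not narwhals API.

import re
from operator import attrgetter

_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
_LOWER_UPPER = re.compile(r"([a-z])([A-Z])")


def pascal_to_snake(name: str) -> str:
    # Same two-pass rewrite as `_pascal_to_snake_case` in common.py.
    return _LOWER_UPPER.sub(r"\1_\2", _UPPER_LOWER.sub(r"\1_\2", name)).lower()


class FakeBackend:
    """Hypothetical compliant-level context; not a narwhals class."""

    def sort_by(self, node: object, frame: object, name: str) -> str:
        return f"sort_by -> {name}"

    def over(self, node: object, frame: object, name: str) -> str:
        return f"over -> {name}"


assert pascal_to_snake("SortBy") == "sort_by"
# The derived name for WindowExpr would be "window_expr", which is presumably why
# the diff supplies ExprIROptions.renamed("over") to point dispatch at the real method.
assert pascal_to_snake("WindowExpr") == "window_expr"

impl = attrgetter("over")(FakeBackend())  # roughly what `_dispatch_getter` resolves to
print(impl(None, None, "a"))              # over -> a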
@@ -207,19 +197,11 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: return names -def _root_names_unique(exprs: Seq[ExprIR]) -> set[str]: - from narwhals._plan.meta import _expr_to_leaf_column_names_iter - - it = chain.from_iterable(_expr_to_leaf_column_names_iter(expr) for expr in exprs) - return set(it) - - def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: if is_horizontal_reduction(child): - return child.with_input( - rewrite_projections(child.input, keys=(), schema=schema) - ) + rewrites = rewrite_projections(child.input, keys=(), schema=schema) + return common.replace(child, input=rewrites) return child return origin.map_ir(fn) @@ -245,18 +227,7 @@ def is_index_in_range(index: int, n_fields: int) -> bool: def remove_alias(origin: ExprIR, /) -> ExprIR: def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Alias): - return child.expr - return child - - return origin.map_ir(fn) - - -def remove_exclude(origin: ExprIR, /) -> ExprIR: - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, Exclude): - return child.expr - return child + return child.expr if isinstance(child, Alias) else child return origin.map_ir(fn) @@ -269,20 +240,14 @@ def replace_with_column( def fn(child: ExprIR, /) -> ExprIR: if isinstance(child, tp): return col(name) - if isinstance(child, Exclude): - return child.expr - return child + return child.expr if isinstance(child, Exclude) else child return origin.map_ir(fn) def replace_selector(ir: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: - """Fully diverging from `polars`, we'll see how that goes.""" - def fn(child: ExprIR, /) -> ExprIR: - if isinstance(child, SelectorIR): - return expand_selector(child, schema=schema) - return child + return expand_selector(child, schema) if isinstance(child, SelectorIR) else child return ir.map_ir(fn) @@ -299,7 +264,7 @@ def selector_matches_column(selector: SelectorIR, name: str, dtype: DType, /) -> @lru_cache(maxsize=100) -def expand_selector(selector: SelectorIR, *, schema: FrozenSchema) -> Columns: +def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns: """Expand `selector` into `Columns`, within the context of `schema`.""" matches = selector_matches_column return cols(*(k for k, v in schema.items() if matches(selector, k, v))) @@ -319,64 +284,46 @@ def rewrite_projections( if flags.has_selector: expanded = replace_selector(expanded, schema=schema) flags = flags.with_multiple_columns() - result.extend( - replace_and_add_to_results( - expanded, keys=keys, col_names=schema.names, flags=flags - ) - ) + result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags)) return tuple(result) -def replace_and_add_to_results( +def iter_replace( origin: ExprIR, /, keys: GroupByKeys, *, col_names: FrozenColumns, flags: ExpansionFlags, -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() +) -> Iterator[ExprIR]: if flags.has_nth: origin = replace_nth(origin, col_names) if flags.expands: it = (e for e in origin.iter_left() if isinstance(e, (Columns, IndexColumns))) if e := next(it, None): if isinstance(e, Columns): - exclude = prepare_excluded( - origin, keys=keys, has_exclude=flags.has_exclude - ) - result.extend(expand_columns(origin, e, exclude=exclude)) + if not _all_columns_match(origin, e): + msg = "expanding more than one `col` is not allowed" + raise ComputeError(msg) + names: Iterable[str] = e.names else: - exclude = prepare_excluded( - origin, keys=keys, has_exclude=flags.has_exclude - ) - 
result.extend( - expand_indices(origin, e, col_names=col_names, exclude=exclude) - ) + names = _iter_index_names(e, col_names) + exclude = prepare_excluded(origin, keys, flags) + yield from expand_column_selection(origin, type(e), names, exclude) elif flags.has_wildcard: - exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - result.extend(replace_wildcard(origin, col_names=col_names, exclude=exclude)) + exclude = prepare_excluded(origin, keys, flags) + yield from expand_column_selection(origin, All, col_names, exclude) else: - exclude = prepare_excluded(origin, keys=keys, has_exclude=flags.has_exclude) - expanded = rewrite_special_aliases(origin) - result.append(expanded) - return tuple(result) - - -def _iter_exclude_names(origin: ExprIR, /) -> Iterator[str]: - """Yield all excluded names in `origin`.""" - for e in origin.iter_left(): - if isinstance(e, Exclude): - yield from e.names + yield rewrite_special_aliases(origin) def prepare_excluded( - origin: ExprIR, /, keys: GroupByKeys, *, has_exclude: bool + origin: ExprIR, keys: GroupByKeys, flags: ExpansionFlags, / ) -> Excluded: """Huge simplification of https://github.com/pola-rs/polars/blob/0fa7141ce718c6f0a4d6ae46865c867b177a59ed/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L484-L555.""" exclude: set[str] = set() - if has_exclude: - exclude.update(_iter_exclude_names(origin)) + if flags.has_exclude: + exclude.update(*(e.names for e in origin.iter_left() if isinstance(e, Exclude))) for group_by_key in keys: if name := group_by_key.meta.output_name(raise_if_undetermined=False): exclude.add(name) @@ -388,52 +335,20 @@ def _all_columns_match(origin: ExprIR, /, columns: Columns) -> bool: return all(it) -def expand_columns( - origin: ExprIR, /, columns: Columns, *, exclude: Excluded -) -> Seq[ExprIR]: - if not _all_columns_match(origin, columns): - msg = "expanding more than one `col` is not allowed" - raise ComputeError(msg) - result: deque[ExprIR] = deque() - for name in columns.names: - if name not in exclude: - expanded = replace_with_column(origin, Columns, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) - - -def expand_indices( - origin: ExprIR, - /, - indices: IndexColumns, - *, - col_names: FrozenColumns, - exclude: Excluded, -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() - n_fields = len(col_names) +def _iter_index_names(indices: IndexColumns, names: FrozenColumns, /) -> Iterator[str]: + n_fields = len(names) for index in indices.indices: if not is_index_in_range(index, n_fields): - raise column_index_error(index, col_names) - name = col_names[index] - if name not in exclude: - expanded = replace_with_column(origin, IndexColumns, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) + raise column_index_error(index, names) + yield names[index] -def replace_wildcard( - origin: ExprIR, /, *, col_names: FrozenColumns, exclude: Excluded -) -> Seq[ExprIR]: - result: deque[ExprIR] = deque() - for name in col_names: +def expand_column_selection( + origin: ExprIR, tp: type[_ColumnSelection], /, names: Iterable[str], exclude: Excluded +) -> Iterator[ExprIR]: + for name in names: if name not in exclude: - expanded = replace_with_column(origin, All, name) - expanded = rewrite_special_aliases(expanded) - result.append(expanded) - return tuple(result) + yield rewrite_special_aliases(replace_with_column(origin, tp, name)) def rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: @@ -444,8 +359,6 @@ def 
rewrite_special_aliases(origin: ExprIR, /) -> ExprIR: - Expanding all selections into `Column` - Dealing with `FunctionExpr.input` """ - from narwhals._plan import meta - if meta.has_expr_ir(origin, KeepName, RenameAlias): if isinstance(origin, KeepName): parent = origin.expr diff --git a/narwhals/_plan/expr_parsing.py b/narwhals/_plan/expr_parsing.py index b1ef22fbc9..1e450f2307 100644 --- a/narwhals/_plan/expr_parsing.py +++ b/narwhals/_plan/expr_parsing.py @@ -6,7 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING, TypeVar -from narwhals._plan.common import is_expr, is_iterable_reject +from narwhals._plan._guards import is_expr, is_iterable_reject from narwhals._plan.exceptions import ( invalid_into_expr_error, is_iterable_pandas_error, @@ -22,7 +22,7 @@ from typing_extensions import TypeAlias, TypeIs from narwhals._plan.common import ExprIR - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq from narwhals.typing import IntoDType T = TypeVar("T") @@ -100,7 +100,7 @@ def parse_into_expr_ir( def parse_into_seq_of_expr_ir( - first_input: IntoExpr | Iterable[IntoExpr] = (), + first_input: OneOrIterable[IntoExpr] = (), *more_inputs: IntoExpr | _RaisesInvalidIntoExprError, **named_inputs: IntoExpr, ) -> Seq[ExprIR]: @@ -109,7 +109,7 @@ def parse_into_seq_of_expr_ir( def parse_predicates_constraints_into_expr_ir( - first_predicate: IntoExprColumn | Iterable[IntoExprColumn] = (), + first_predicate: OneOrIterable[IntoExprColumn] = (), *more_predicates: IntoExprColumn | _RaisesInvalidIntoExprError, **constraints: IntoExpr, ) -> ExprIR: @@ -125,9 +125,7 @@ def parse_predicates_constraints_into_expr_ir( def _parse_into_iter_expr_ir( - first_input: IntoExpr | Iterable[IntoExpr], - *more_inputs: IntoExpr, - **named_inputs: IntoExpr, + first_input: OneOrIterable[IntoExpr], *more_inputs: IntoExpr, **named_inputs: IntoExpr ) -> Iterator[ExprIR]: if not _is_empty_sequence(first_input): # NOTE: These need to be separated to introduce an intersection type diff --git a/narwhals/_plan/expr_rewrites.py b/narwhals/_plan/expr_rewrites.py index 705dec3371..597e8afc21 100644 --- a/narwhals/_plan/expr_rewrites.py +++ b/narwhals/_plan/expr_rewrites.py @@ -5,14 +5,13 @@ from typing import TYPE_CHECKING from narwhals._plan import expr_parsing as parse -from narwhals._plan.common import ( - NamedIR, +from narwhals._plan._guards import ( is_aggregation, is_binary_expr, is_function_expr, is_window_expr, - map_ir, ) +from narwhals._plan.common import NamedIR, map_ir, replace from narwhals._plan.expr_expansion import into_named_irs, prepare_projection if TYPE_CHECKING: @@ -59,7 +58,7 @@ def rewrite_elementwise_over(window: ExprIR, /) -> ExprIR: ): func = window.expr parent, *args = func.input - return func.with_input((window.with_expr(parent), *args)) + return replace(func, input=(replace(window, expr=parent), *args)) return window @@ -84,6 +83,5 @@ def rewrite_binary_agg_over(window: ExprIR, /) -> ExprIR: and (is_aggregation(window.expr.right)) ): binary_expr = window.expr - rhs = window.expr.right - return binary_expr.with_right(window.with_expr(rhs)) + return replace(binary_expr, right=replace(window, expr=binary_expr.right)) return window diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 4c80849a89..570d75b4d0 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Function +from 
narwhals._plan.common import Function, HorizontalFunction from narwhals._plan.exceptions import hist_bins_monotonic_error from narwhals._plan.options import FunctionFlags, FunctionOptions @@ -21,10 +21,48 @@ from narwhals.typing import FillNullStrategy -class Abs(Function, options=FunctionOptions.elementwise): ... +class CumAgg(Function, options=FunctionOptions.length_preserving): + __slots__ = ("reverse",) + reverse: bool + + +class RollingWindow(Function, options=FunctionOptions.length_preserving): + __slots__ = ("options",) + options: RollingOptionsFixedWindow + def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: + from narwhals._plan.expr import RollingExpr + + options = self.function_options + return RollingExpr(input=inputs, function=self, options=options) -class Hist(Function, options=FunctionOptions.groupwise): + +# fmt: off +class Abs(Function, options=FunctionOptions.elementwise): ... +class NullCount(Function, options=FunctionOptions.aggregation): ... +class Exp(Function, options=FunctionOptions.elementwise): ... +class Sqrt(Function, options=FunctionOptions.elementwise): ... +class DropNulls(Function, options=FunctionOptions.row_separable): ... +class Mode(Function): ... +class Skew(Function, options=FunctionOptions.aggregation): ... +class Clip(Function, options=FunctionOptions.elementwise): ... +class CumCount(CumAgg): ... +class CumMin(CumAgg): ... +class CumMax(CumAgg): ... +class CumProd(CumAgg): ... +class CumSum(CumAgg): ... +class RollingSum(RollingWindow): ... +class RollingMean(RollingWindow): ... +class RollingVar(RollingWindow): ... +class RollingStd(RollingWindow): ... +class Diff(Function, options=FunctionOptions.length_preserving): ... +class Unique(Function): ... +class SumHorizontal(HorizontalFunction): ... +class MinHorizontal(HorizontalFunction): ... +class MaxHorizontal(HorizontalFunction): ... +class MeanHorizontal(HorizontalFunction): ... +# fmt: on +class Hist(Function): """Only supported for `Series` so far.""" __slots__ = ("include_breakpoint",) @@ -55,17 +93,11 @@ def __init__(self, *, bin_count: int = 10, include_breakpoint: bool = True) -> N object.__setattr__(self, "include_breakpoint", include_breakpoint) -class NullCount(Function, options=FunctionOptions.aggregation): ... - - class Log(Function, options=FunctionOptions.elementwise): __slots__ = ("base",) base: float -class Exp(Function, options=FunctionOptions.elementwise): ... - - class Pow(Function, options=FunctionOptions.elementwise): """N-ary (base, exponent).""" @@ -74,9 +106,6 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR]: return base, exponent -class Sqrt(Function, options=FunctionOptions.elementwise): ... - - class Kurtosis(Function, options=FunctionOptions.aggregation): __slots__ = ("bias", "fisher") fisher: bool @@ -102,89 +131,16 @@ class Shift(Function, options=FunctionOptions.length_preserving): n: int -class DropNulls(Function, options=FunctionOptions.row_separable): ... - - -class Mode(Function, options=FunctionOptions.groupwise): ... - - -class Skew(Function, options=FunctionOptions.aggregation): ... - - -class Rank(Function, options=FunctionOptions.groupwise): +class Rank(Function): __slots__ = ("options",) options: RankOptions -class Clip(Function, options=FunctionOptions.elementwise): ... 
- - -class CumAgg(Function, options=FunctionOptions.length_preserving): - __slots__ = ("reverse",) - reverse: bool - - -class RollingWindow(Function, options=FunctionOptions.length_preserving): - __slots__ = ("options",) - options: RollingOptionsFixedWindow - - def to_function_expr(self, *inputs: ExprIR) -> RollingExpr[Self]: - from narwhals._plan.expr import RollingExpr - - options = self.function_options - return RollingExpr(input=inputs, function=self, options=options) - - -class CumCount(CumAgg): ... - - -class CumMin(CumAgg): ... - - -class CumMax(CumAgg): ... - - -class CumProd(CumAgg): ... - - -class CumSum(CumAgg): ... - - -class RollingSum(RollingWindow): ... - - -class RollingMean(RollingWindow): ... - - -class RollingVar(RollingWindow): ... - - -class RollingStd(RollingWindow): ... - - -class Diff(Function, options=FunctionOptions.length_preserving): ... - - -class Unique(Function, options=FunctionOptions.groupwise): ... - - class Round(Function, options=FunctionOptions.elementwise): __slots__ = ("decimals",) decimals: int -class SumHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MinHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MaxHorizontal(Function, options=FunctionOptions.horizontal): ... - - -class MeanHorizontal(Function, options=FunctionOptions.horizontal): ... - - class EwmMean(Function, options=FunctionOptions.length_preserving): __slots__ = ("options",) options: EWMOptions @@ -197,7 +153,7 @@ class ReplaceStrict(Function, options=FunctionOptions.elementwise): return_dtype: DType | None -class GatherEvery(Function, options=FunctionOptions.groupwise): +class GatherEvery(Function): __slots__ = ("n", "offset") n: int offset: int diff --git a/narwhals/_plan/lists.py b/narwhals/_plan/lists.py index 046db5615d..f4a45f217f 100644 --- a/narwhals/_plan/lists.py +++ b/narwhals/_plan/lists.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace from narwhals._plan.options import FunctionOptions @@ -9,15 +9,12 @@ from narwhals._plan.dummy import Expr +# fmt: off class ListFunction(Function, accessor="list"): ... - - class Len(ListFunction, options=FunctionOptions.elementwise): ... 
- - +# fmt: on class IRListNamespace(IRNamespace): - def len(self) -> Len: - return Len() + len: ClassVar = Len class ExprListNamespace(ExprNamespace[IRListNamespace]): diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index e0dba305fa..94f5a9a5b4 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any, Generic -from narwhals._plan.common import Immutable +from narwhals._plan._guards import is_literal +from narwhals._plan._immutable import Immutable from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: @@ -74,31 +75,7 @@ def unwrap(self) -> Series[NativeSeriesT]: return self.value -def _is_scalar( - obj: ScalarLiteral[NonNestedLiteralT] | Any, -) -> TypeIs[ScalarLiteral[NonNestedLiteralT]]: - return isinstance(obj, ScalarLiteral) - - -def _is_series( - obj: SeriesLiteral[NativeSeriesT] | Any, -) -> TypeIs[SeriesLiteral[NativeSeriesT]]: - return isinstance(obj, SeriesLiteral) - - -def is_literal(obj: Literal[LiteralT] | Any) -> TypeIs[Literal[LiteralT]]: - from narwhals._plan.expr import Literal - - return isinstance(obj, Literal) - - def is_literal_scalar( obj: Literal[NonNestedLiteralT] | Any, ) -> TypeIs[Literal[NonNestedLiteralT]]: - return is_literal(obj) and _is_scalar(obj.value) - - -def is_literal_series( - obj: Literal[Series[NativeSeriesT]] | Any, -) -> TypeIs[Literal[Series[NativeSeriesT]]]: - return is_literal(obj) and _is_series(obj.value) + return is_literal(obj) and isinstance(obj.value, ScalarLiteral) diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index 1a5cbb8eac..ce78165c00 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -7,6 +7,7 @@ from __future__ import annotations from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING, Literal, overload from narwhals._plan.common import IRNamespace @@ -14,7 +15,7 @@ from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from typing_extensions import TypeIs @@ -74,16 +75,7 @@ def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: def root_names(self) -> list[str]: """Get the root column names.""" - return _expr_to_leaf_column_names(self._ir) - - -def _expr_to_leaf_column_names(ir: ExprIR) -> list[str]: - """After a lot of indirection, [root_names] resolves [here]. 
- - [root_names]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L27-L30 - [here]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/utils.rs#L171-L195 - """ - return list(_expr_to_leaf_column_names_iter(ir)) + return list(_expr_to_leaf_column_names_iter(self._ir)) def _expr_to_leaf_column_names_iter(ir: ExprIR) -> Iterator[str]: @@ -121,6 +113,10 @@ def _expr_to_leaf_column_name(ir: ExprIR) -> str | ComputeError: return ComputeError(msg) +def root_names_unique(irs: Iterable[ExprIR], /) -> set[str]: + return set(chain.from_iterable(_expr_to_leaf_column_names_iter(e) for e in irs)) + + @lru_cache(maxsize=32) def _expr_output_name(ir: ExprIR) -> str | ComputeError: from narwhals._plan import expr @@ -186,26 +182,22 @@ def is_column(ir: ExprIR) -> TypeIs[Column]: def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr - from narwhals._plan.literal import ScalarLiteral - - if isinstance(ir, expr.Literal): - return True - if isinstance(ir, expr.Alias): - return allow_aliasing - if isinstance(ir, expr.Cast): - return ( - isinstance(ir.expr, expr.Literal) - and isinstance(ir.expr, ScalarLiteral) + from narwhals._plan.literal import is_literal_scalar + + return ( + isinstance(ir, expr.Literal) + or (allow_aliasing and isinstance(ir, expr.Alias)) + or ( + isinstance(ir, expr.Cast) + and is_literal_scalar(ir.expr) and isinstance(ir.expr.dtype, Version.MAIN.dtypes.Datetime) ) - return False + ) def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool: from narwhals._plan import expr - if isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)): - return True - if isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)): - return allow_aliasing - return False + return isinstance(ir, (expr.Column, expr._ColumnSelection, expr.SelectorIR)) or ( + allow_aliasing and isinstance(ir, (expr.Alias, expr.KeepName, expr.RenameAlias)) + ) diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 7c695599bc..4147f20450 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -3,21 +3,17 @@ from typing import TYPE_CHECKING from narwhals._plan import common -from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace +from narwhals._plan._immutable import Immutable +from narwhals._plan.options import ExprIROptions if TYPE_CHECKING: - from collections.abc import Iterator - - from typing_extensions import Self - from narwhals._compliant.typing import AliasName from narwhals._plan.dummy import Expr - from narwhals._plan.typing import MapIR -class KeepName(ExprIR): +class KeepName(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr",) - expr: ExprIR + expr: common.ExprIR @property def is_scalar(self) -> bool: @@ -26,24 +22,10 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f"{self.expr!r}.name.keep()" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - -class RenameAlias(ExprIR): +class RenameAlias(common.ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "function") - expr: 
ExprIR + expr: common.ExprIR function: AliasName @property @@ -53,20 +35,6 @@ def is_scalar(self) -> bool: def __repr__(self) -> str: return f".rename_alias({self.expr!r})" - def iter_left(self) -> Iterator[ExprIR]: - yield from self.expr.iter_left() - yield self - - def iter_right(self) -> Iterator[ExprIR]: - yield self - yield from self.expr.iter_right() - - def map_ir(self, function: MapIR, /) -> ExprIR: - return function(self.with_expr(self.expr.map_ir(function))) - - def with_expr(self, expr: ExprIR, /) -> Self: - return common.replace(self, expr=expr) - class Prefix(Immutable): __slots__ = ("prefix",) @@ -84,7 +52,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class IRNameNamespace(IRNamespace): +class IRNameNamespace(common.IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -104,7 +72,7 @@ def to_uppercase(self) -> RenameAlias: return self.map(str.upper) -class ExprNameNamespace(ExprNamespace[IRNameNamespace]): +class ExprNameNamespace(common.ExprNamespace[IRNameNamespace]): @property def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 09d072e7bd..78b33b042f 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,15 +1,15 @@ from __future__ import annotations -import operator +import operator as op from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr +from narwhals._plan._guards import is_function_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions import ( binary_expr_length_changing_error, binary_expr_multi_output_error, binary_expr_shape_error, ) -from narwhals._plan.expr import BinarySelector if TYPE_CHECKING: from typing import Any, ClassVar @@ -28,30 +28,19 @@ class Operator(Immutable): - _op: ClassVar[OperatorFn] + _func: ClassVar[OperatorFn] + _symbol: ClassVar[str] def __repr__(self) -> str: - tp = type(self) - if tp in {Operator, SelectorOperator}: - return tp.__name__ - m = { - Eq: "==", - NotEq: "!=", - Lt: "<", - LtEq: "<=", - Gt: ">", - GtEq: ">=", - Add: "+", - Sub: "-", - Multiply: "*", - TrueDivide: "/", - FloorDivide: "//", - Modulus: "%", - And: "&", - Or: "|", - ExclusiveOr: "^", - } - return m[tp] + return self._symbol + + def __init_subclass__( + cls, *args: Any, func: OperatorFn | None, symbol: str = "", **kwds: Any + ) -> None: + super().__init_subclass__(*args, **kwds) + if func: + cls._func = func + cls._symbol = symbol or cls.__name__ def to_binary_expr( self, left: LeftT, right: RightT, / @@ -72,16 +61,14 @@ def to_binary_expr( def __call__(self, lhs: Any, rhs: Any) -> Any: """Apply binary operator to `left`, `right` operands.""" - return self.__class__._op(lhs, rhs) + return self.__class__._func(lhs, rhs) def _is_filtration(ir: ExprIR) -> bool: - if not ir.is_scalar and is_function_expr(ir): - return not ir.options.is_elementwise() - return False + return not ir.is_scalar and is_function_expr(ir) and not ir.options.is_elementwise() -class SelectorOperator(Operator): +class SelectorOperator(Operator, func=None): """Operators that can *also* be used in selectors.""" def to_binary_selector( @@ -92,61 +79,20 @@ def to_binary_selector( return BinarySelector(left=left, op=self, right=right) -class Eq(Operator): - _op = operator.eq - - -class NotEq(Operator): - _op = operator.ne - - -class Lt(Operator): - _op = operator.le - - -class LtEq(Operator): - _op = operator.lt - - -class Gt(Operator): - _op = 
operator.gt - - -class GtEq(Operator): - _op = operator.ge - - -class Add(Operator): - _op = operator.add - - -class Sub(SelectorOperator): - _op = operator.sub - - -class Multiply(Operator): - _op = operator.mul - - -class TrueDivide(Operator): - _op = operator.truediv - - -class FloorDivide(Operator): - _op = operator.floordiv - - -class Modulus(Operator): - _op = operator.mod - - -class And(SelectorOperator): - _op = operator.and_ - - -class Or(SelectorOperator): - _op = operator.or_ - - -class ExclusiveOr(SelectorOperator): - _op = operator.xor +# fmt: off +class Eq(Operator, func=op.eq, symbol="=="): ... +class NotEq(Operator, func=op.ne, symbol="!="): ... +class Lt(Operator, func=op.lt, symbol="<"): ... +class LtEq(Operator, func=op.le, symbol="<="): ... +class Gt(Operator, func=op.gt, symbol=">"): ... +class GtEq(Operator, func=op.ge, symbol=">="): ... +class Add(Operator, func=op.add, symbol="+"): ... +class Sub(SelectorOperator, func=op.sub, symbol="-"): ... +class Multiply(Operator, func=op.mul, symbol="*"): ... +class TrueDivide(Operator, func=op.truediv, symbol="/"): ... +class FloorDivide(Operator, func=op.floordiv, symbol="//"): ... +class Modulus(Operator, func=op.mod, symbol="%"): ... +class And(SelectorOperator, func=op.and_, symbol="&"): ... +class Or(SelectorOperator, func=op.or_, symbol="|"): ... +class ExclusiveOr(SelectorOperator, func=op.xor, symbol="^"): ... +# fmt: on diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ca6cf91a04..6f77674dff 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -4,16 +4,19 @@ from itertools import repeat from typing import TYPE_CHECKING, Literal -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable if TYPE_CHECKING: from collections.abc import Iterable, Sequence import pyarrow.compute as pc + from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Seq + from narwhals._plan.typing import Accessor, OneOrIterable, Seq from narwhals.typing import RankMethod +DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] + class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -184,7 +187,7 @@ def __repr__(self) -> str: @staticmethod def parse( - *, descending: bool | Iterable[bool], nulls_last: bool | Iterable[bool] + *, descending: OneOrIterable[bool], nulls_last: OneOrIterable[bool] ) -> SortMultipleOptions: desc = (descending,) if isinstance(descending, bool) else tuple(descending) nulls = (nulls_last,) if isinstance(nulls_last, bool) else tuple(nulls_last) @@ -263,3 +266,56 @@ def rolling_options( center=center, fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), ) + + +class _BaseIROptions(Immutable): + __slots__ = ("origin", "override_name") + origin: DispatchOrigin + override_name: str + + def __repr__(self) -> str: + return self.__str__() + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="") + + @classmethod + def renamed(cls, name: str, /) -> Self: + from narwhals._plan.common import replace + + return replace(cls.default(), override_name=name) + + @classmethod + def namespaced(cls, override_name: str = "", /) -> Self: + from narwhals._plan.common import replace + + return replace( + cls.default(), origin="__narwhals_namespace__", override_name=override_name + ) + + +class ExprIROptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "allow_dispatch") + allow_dispatch: bool + + @classmethod + def default(cls) -> Self: + return cls(origin="expr",
override_name="", allow_dispatch=True) + + @staticmethod + def no_dispatch() -> ExprIROptions: + return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + + +class FunctionExprOptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "accessor_name") + accessor_name: Accessor | None + """Namespace accessor name, if any.""" + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="", accessor_name=None) + + +FEOptions = FunctionExprOptions diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 821e7a338a..951dfa850e 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,27 +1,30 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings -from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe +from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar -from narwhals._utils import Version, _hasattr_static +from narwhals._utils import Version if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._plan import aggregation as agg, boolean, expr, functions as F + from narwhals._plan.boolean import IsBetween, IsFinite, IsNan, IsNull, Not from narwhals._plan.dummy import BaseFrame, DataFrame, Series - from narwhals._plan.expr import FunctionExpr, RangeExpr + from narwhals._plan.expr import BinaryExpr, FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange + from narwhals._plan.strings import ConcatStr + from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType - from narwhals.schema import Schema from narwhals.typing import ( ConcatMethod, Into1DArray, IntoDType, + IntoSchema, NonNestedLiteral, PythonLiteral, _1DArray, @@ -29,7 +32,6 @@ T = TypeVar("T") R_co = TypeVar("R_co", covariant=True) -OneOrIterable: TypeAlias = "T | Iterable[T]" LengthT = TypeVar("LengthT") NativeT_co = TypeVar("NativeT_co", covariant=True, default=Any) @@ -69,6 +71,22 @@ LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) +Ctx: TypeAlias = "ExprDispatch[FrameT_contra, R_co, NamespaceAny]" +"""Type of an unknown expression dispatch context. + +- `FrameT_contra`: Compliant data/lazyframe +- `R_co`: Upper bound return type of the context +""" + + +class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]): + def __narwhals_namespace__(self) -> NamespaceT_co: ... 
+ + +def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co: + """Return the compliant namespace.""" + return obj.__narwhals_namespace__() + # NOTE: Unlike the version in `nw._utils`, here `.version` it is public class StoresVersion(Protocol): @@ -143,130 +161,11 @@ def _length_required( class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): - _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { - expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( - node, frame, name - ), - expr.Literal: lambda self, node, frame, name: self.__narwhals_namespace__().lit( - node, frame, name - ), - expr.Len: lambda self, node, frame, name: self.__narwhals_namespace__().len( - node, frame, name - ), - expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), - expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), - expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), - expr.Filter: lambda self, node, frame, name: self.filter(node, frame, name), - agg.First: lambda self, node, frame, name: self.first(node, frame, name), - agg.Last: lambda self, node, frame, name: self.last(node, frame, name), - agg.ArgMin: lambda self, node, frame, name: self.arg_min(node, frame, name), - agg.ArgMax: lambda self, node, frame, name: self.arg_max(node, frame, name), - agg.Sum: lambda self, node, frame, name: self.sum(node, frame, name), - agg.NUnique: lambda self, node, frame, name: self.n_unique(node, frame, name), - agg.Std: lambda self, node, frame, name: self.std(node, frame, name), - agg.Var: lambda self, node, frame, name: self.var(node, frame, name), - agg.Quantile: lambda self, node, frame, name: self.quantile(node, frame, name), - agg.Count: lambda self, node, frame, name: self.count(node, frame, name), - agg.Max: lambda self, node, frame, name: self.max(node, frame, name), - agg.Mean: lambda self, node, frame, name: self.mean(node, frame, name), - agg.Median: lambda self, node, frame, name: self.median(node, frame, name), - agg.Min: lambda self, node, frame, name: self.min(node, frame, name), - expr.BinaryExpr: lambda self, node, frame, name: self.binary_expr( - node, frame, name - ), - expr.RollingExpr: lambda self, node, frame, name: self.rolling_expr( - node, frame, name - ), - expr.AnonymousExpr: lambda self, node, frame, name: self.map_batches( - node, frame, name - ), - expr.FunctionExpr: lambda self, node, frame, name: self._dispatch_function( - node, frame, name - ), - # NOTE: Keeping it simple for now - # When adding other `*_range` functions, this should instead map to `range_expr` - expr.RangeExpr: lambda self, - node, - frame, - name: self.__narwhals_namespace__().int_range(node, frame, name), - expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered( - node, frame, name - ), - expr.WindowExpr: lambda self, node, frame, name: self.over(node, frame, name), - expr.Ternary: lambda self, node, frame, name: self.ternary_expr( - node, frame, name - ), - } - _DISPATCH_FUNCTION: ClassVar[ - Mapping[type[Function], Callable[[Any, FunctionExpr, Any, str], Any]] - ] = { - boolean.AnyHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().any_horizontal(node, frame, name), - boolean.AllHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().all_horizontal(node, frame, name), - F.SumHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().sum_horizontal(node, 
frame, name), - F.MinHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().min_horizontal(node, frame, name), - F.MaxHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().max_horizontal(node, frame, name), - F.MeanHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().mean_horizontal(node, frame, name), - strings.ConcatHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().concat_str(node, frame, name), - F.Pow: lambda self, node, frame, name: self.pow(node, frame, name), - F.FillNull: lambda self, node, frame, name: self.fill_null(node, frame, name), - boolean.IsBetween: lambda self, node, frame, name: self.is_between( - node, frame, name - ), - boolean.IsFinite: lambda self, node, frame, name: self.is_finite( - node, frame, name - ), - boolean.IsNan: lambda self, node, frame, name: self.is_nan(node, frame, name), - boolean.IsNull: lambda self, node, frame, name: self.is_null(node, frame, name), - boolean.Not: lambda self, node, frame, name: self.not_(node, frame, name), - boolean.Any: lambda self, node, frame, name: self.any(node, frame, name), - boolean.All: lambda self, node, frame, name: self.all(node, frame, name), - } - - def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: - if (method := self._DISPATCH.get(node.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - - def _dispatch_function( - self, node: FunctionExpr, frame: FrameT_contra, name: str - ) -> R_co: - fn = node.function - if (method := self._DISPATCH_FUNCTION.get(fn.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {fn.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame.version - return obj._dispatch(node, frame, name) + return node.dispatch(obj, frame, name) @classmethod def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: @@ -284,41 +183,35 @@ class CompliantExpr(StoresVersion, Protocol[FrameT_contra, SeriesT_co]): @property def name(self) -> str: ... - @classmethod def from_native( cls, native: Any, name: str = "", /, version: Version = Version.MAIN ) -> Self: ... - def _with_native(self, native: Any, name: str, /) -> Self: return self.from_native(native, name or self.name, self.version) # series & scalar def cast(self, node: expr.Cast, frame: FrameT_contra, name: str) -> Self: ... def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Self: ... - def not_( - self, node: FunctionExpr[boolean.Not], frame: FrameT_contra, name: str - ) -> Self: ... + def not_(self, node: FunctionExpr[Not], frame: FrameT_contra, name: str) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... def is_between( - self, node: FunctionExpr[boolean.IsBetween], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsBetween], frame: FrameT_contra, name: str ) -> Self: ... def is_finite( - self, node: FunctionExpr[boolean.IsFinite], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsFinite], frame: FrameT_contra, name: str ) -> Self: ... 
def is_nan( - self, node: FunctionExpr[boolean.IsNan], frame: FrameT_contra, name: str + self, node: FunctionExpr[IsNan], frame: FrameT_contra, name: str ) -> Self: ... def is_null( - self, node: FunctionExpr[boolean.IsNull], frame: FrameT_contra, name: str - ) -> Self: ... - def binary_expr( - self, node: expr.BinaryExpr, frame: FrameT_contra, name: str + self, node: FunctionExpr[IsNull], frame: FrameT_contra, name: str ) -> Self: ... + def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def ternary_expr( - self, node: expr.Ternary, frame: FrameT_contra, name: str + self, node: expr.TernaryExpr, frame: FrameT_contra, name: str ) -> Self: ... def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` @@ -406,7 +299,6 @@ def from_python( dtype: IntoDType | None, version: Version, ) -> Self: ... - def _with_evaluated(self, evaluated: Any, name: str) -> Self: """Expr is based on a series having these via accessors, but a scalar needs to keep passing through.""" cls = type(self) @@ -526,7 +418,7 @@ def concat( self, items: Iterable[ConcatT2], *, how: Literal["vertical"] ) -> ConcatT2: ... def concat( - self, items: Iterable[ConcatT1] | Iterable[ConcatT2], *, how: ConcatMethod + self, items: Iterable[ConcatT1 | ConcatT2], *, how: ConcatMethod ) -> ConcatT1 | ConcatT2: ... @@ -536,7 +428,7 @@ def _concat_diagonal(self, items: Iterable[ConcatT1], /) -> ConcatT1: ... # but that is only available privately def _concat_horizontal(self, items: Iterable[ConcatT1 | ConcatT2], /) -> ConcatT1: ... def _concat_vertical( - self, items: Iterable[ConcatT1] | Iterable[ConcatT2], / + self, items: Iterable[ConcatT1 | ConcatT2], / ) -> ConcatT1 | ConcatT2: ... @@ -571,7 +463,7 @@ def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def concat_str( - self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str + self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def int_range( self, node: RangeExpr[IntRange], frame: FrameT, name: str @@ -605,17 +497,9 @@ def lit( def lit( self, node: expr.Literal[Series[Any]], frame: EagerDataFrameT, name: str ) -> EagerExprT_co: ... - @overload - def lit( - self, - node: expr.Literal[NonNestedLiteral] | expr.Literal[Series[Any]], - frame: EagerDataFrameT, - name: str, - ) -> EagerExprT_co | EagerScalarT_co: ... def lit( self, node: expr.Literal[Any], frame: EagerDataFrameT, name: str ) -> EagerExprT_co | EagerScalarT_co: ... - def len(self, node: expr.Len, frame: EagerDataFrameT, name: str) -> EagerScalarT_co: return self._scalar.from_python( len(frame), name or node.name, dtype=None, version=frame.version @@ -645,7 +529,6 @@ def native(self) -> NativeFrameT: @property def columns(self) -> list[str]: ... def to_narwhals(self) -> BaseFrame[NativeFrameT]: ... - @classmethod def from_native(cls, native: NativeFrameT, /, version: Version) -> Self: obj = cls.__new__(cls) @@ -672,15 +555,9 @@ class CompliantDataFrame( ): @classmethod def from_dict( - cls, - data: Mapping[str, Any], - /, - *, - schema: Mapping[str, DType] | Schema | None = None, + cls, data: Mapping[str, Any], /, *, schema: IntoSchema | None = None ) -> Self: ... - def to_narwhals(self) -> DataFrame[NativeDataFrameT, NativeSeriesT]: ... - @overload def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ... 
@overload @@ -689,11 +566,9 @@ def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ... def to_dict( self, *, as_series: bool ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def to_dict( self, *, as_series: bool ) -> dict[str, SeriesT] | dict[str, list[Any]]: ... - def __len__(self) -> int: ... def with_row_index(self, name: str) -> Self: ... @@ -704,12 +579,10 @@ class EagerDataFrame( ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): @@ -725,7 +598,6 @@ def native(self) -> NativeSeriesT: @property def dtype(self) -> DType: ... - @property def name(self) -> str: return self._name @@ -739,9 +611,6 @@ def to_narwhals(self) -> Series[NativeSeriesT]: def from_native( cls, native: NativeSeriesT, name: str = "", /, *, version: Version = Version.MAIN ) -> Self: - name = name or ( - getattr(native, "name", name) if _hasattr_static(native, "name") else name - ) obj = cls.__new__(cls) obj._native = native obj._name = name @@ -752,7 +621,6 @@ def from_native( def from_numpy( cls, data: Into1DArray, name: str = "", /, *, version: Version = Version.MAIN ) -> Self: ... - @classmethod def from_iterable( cls, @@ -762,7 +630,6 @@ def from_iterable( name: str = "", dtype: IntoDType | None = None, ) -> Self: ... 
- def _with_native(self, native: NativeSeriesT) -> Self: return self.from_native(native, self.name, version=self.version) diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index 4414afabf7..4f8e49b531 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprIR, Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing_extensions import Self @@ -12,7 +12,7 @@ from narwhals.dtypes import IntegerType -class RangeFunction(Function): +class RangeFunction(Function, config=FEOptions.namespaced()): def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: from narwhals._plan.expr import RangeExpr diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 17b8416285..69c1b5a2b3 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,7 +7,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload -from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable, NamedIR +from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._plan.common import NamedIR from narwhals.dtypes import Unknown if TYPE_CHECKING: @@ -95,8 +96,7 @@ def _from_mapping(mapping: MappingProxyType[str, DType], /) -> FrozenSchema: @staticmethod def _from_hash_safe(items: _FrozenSchemaHash, /) -> FrozenSchema: - clone = MappingProxyType(dict(items)) - return FrozenSchema._from_mapping(clone) + return FrozenSchema._from_mapping(MappingProxyType(dict(items))) def items(self) -> ItemsView[str, DType]: return self._mapping.items() diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 3cd7666ddc..4aa5f58a3d 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -9,16 +9,18 @@ import re from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, flatten_hash_safe +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import flatten_hash_safe from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterator from datetime import timezone from typing import TypeVar from narwhals._plan import dummy from narwhals._plan.expr import RootSelector + from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType from narwhals.typing import TimeUnit @@ -50,9 +52,7 @@ class ByDType(Selector): dtypes: frozenset[DType | type[DType]] @staticmethod - def from_dtypes( - *dtypes: DType | type[DType] | Iterable[DType | type[DType]], - ) -> ByDType: + def from_dtypes(*dtypes: OneOrIterable[DType | type[DType]]) -> ByDType: return ByDType(dtypes=frozenset(flatten_hash_safe(dtypes))) def __repr__(self) -> str: @@ -95,8 +95,8 @@ class Datetime(Selector): @staticmethod def from_time_unit_and_time_zone( - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, + time_unit: OneOrIterable[TimeUnit] | None, + time_zone: OneOrIterable[str | timezone | None], /, ) -> Datetime: units, zones = _parse_time_unit_and_time_zone(time_unit, time_zone) @@ -125,11 +125,10 @@ def from_string(pattern: str, /) -> Matches: return Matches(pattern=re.compile(pattern)) @staticmethod - def from_names(*names: str | Iterable[str]) -> Matches: + def from_names(*names: OneOrIterable[str]) -> Matches: 
"""Implements `cs.by_name` to support `__r__` with column selections.""" it: Iterator[str] = flatten_hash_safe(names) - pattern = f"^({'|'.join(re.escape(name) for name in it)})$" - return Matches.from_string(pattern) + return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$") def __repr__(self) -> str: return f"ncs.matches(pattern={self.pattern.pattern!r})" @@ -158,13 +157,11 @@ def all() -> dummy.Selector: return All().to_selector().to_narwhals() -def by_dtype( - *dtypes: DType | type[DType] | Iterable[DType | type[DType]], -) -> dummy.Selector: +def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> dummy.Selector: return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals() -def by_name(*names: str | Iterable[str]) -> dummy.Selector: +def by_name(*names: OneOrIterable[str]) -> dummy.Selector: return Matches.from_names(*names).to_selector().to_narwhals() @@ -177,8 +174,8 @@ def categorical() -> dummy.Selector: def datetime( - time_unit: TimeUnit | Iterable[TimeUnit] | None = None, - time_zone: str | timezone | Iterable[str | timezone | None] | None = ("*", None), + time_unit: OneOrIterable[TimeUnit] | None = None, + time_zone: OneOrIterable[str | timezone | None] = ("*", None), ) -> dummy.Selector: return ( Datetime.from_time_unit_and_time_zone(time_unit, time_zone) diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 8a1789b079..4c1f4af303 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -1,20 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan.common import ExprNamespace, Function, HorizontalFunction, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr +# fmt: off class StringFunction(Function, accessor="str", options=FunctionOptions.elementwise): ... - - -class ConcatHorizontal(StringFunction, options=FunctionOptions.horizontal): - """`nw.functions.concat_str`.""" - +class LenChars(StringFunction): ... +class ToLowercase(StringFunction): ... +class ToUppercase(StringFunction): ... +# fmt: on +class ConcatStr(HorizontalFunction, StringFunction): __slots__ = ("ignore_nulls", "separator") separator: str ignore_nulls: bool @@ -31,9 +32,6 @@ class EndsWith(StringFunction): suffix: str -class LenChars(StringFunction): ... - - class Replace(StringFunction): __slots__ = ("literal", "n", "pattern", "value") pattern: str @@ -75,15 +73,13 @@ class ToDatetime(StringFunction): format: str | None -class ToLowercase(StringFunction): ... - - -class ToUppercase(StringFunction): ... 
- - class IRStringNamespace(IRNamespace): - def len_chars(self) -> LenChars: - return LenChars() + len_chars: ClassVar = LenChars + to_lowercase: ClassVar = ToLowercase + to_uppercase: ClassVar = ToUppercase + split: ClassVar = Split + starts_with: ClassVar = StartsWith + ends_with: ClassVar = EndsWith def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 @@ -98,12 +94,6 @@ def replace_all( def strip_chars(self, characters: str | None = None) -> StripChars: return StripChars(characters=characters) - def starts_with(self, prefix: str) -> StartsWith: - return StartsWith(prefix=prefix) - - def ends_with(self, suffix: str) -> EndsWith: - return EndsWith(suffix=suffix) - def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) @@ -116,18 +106,9 @@ def head(self, n: int = 5) -> Slice: def tail(self, n: int = 5) -> Slice: return self.slice(-n) - def split(self, by: str) -> Split: - return Split(by=by) - def to_datetime(self, format: str | None = None) -> ToDatetime: return ToDatetime(format=format) - def to_lowercase(self) -> ToUppercase: - return ToUppercase() - - def to_uppercase(self) -> ToLowercase: - return ToLowercase() - class ExprStringNamespace(ExprNamespace[IRStringNamespace]): @property @@ -149,10 +130,10 @@ def strip_chars(self, characters: str | None = None) -> Expr: return self._with_unary(self._ir.strip_chars(characters)) def starts_with(self, prefix: str) -> Expr: - return self._with_unary(self._ir.starts_with(prefix)) + return self._with_unary(self._ir.starts_with(prefix=prefix)) def ends_with(self, suffix: str) -> Expr: - return self._with_unary(self._ir.ends_with(suffix)) + return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) @@ -167,7 +148,7 @@ def tail(self, n: int = 5) -> Expr: return self._with_unary(self._ir.tail(n)) def split(self, by: str) -> Expr: - return self._with_unary(self._ir.split(by)) + return self._with_unary(self._ir.split(by=by)) def to_datetime(self, format: str | None = None) -> Expr: return self._with_unary(self._ir.to_datetime(format)) diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index d91fef6458..2a3eca0b27 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr @@ -12,9 +12,9 @@ class StructFunction(Function, accessor="struct"): ...
-class FieldByName(StructFunction, options=FunctionOptions.elementwise): - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" - +class FieldByName( + StructFunction, options=FunctionOptions.elementwise, config=FEOptions.renamed("field") +): __slots__ = ("name",) name: str @@ -23,8 +23,7 @@ def __repr__(self) -> str: class IRStructNamespace(IRNamespace): - def field(self, name: str) -> FieldByName: - return FieldByName(name=name) + field: ClassVar = FieldByName class ExprStructNamespace(ExprNamespace[IRStructNamespace]): @@ -33,4 +32,4 @@ def _ir_namespace(self) -> type[IRStructNamespace]: return IRStructNamespace def field(self, name: str) -> Expr: - return self._with_unary(self._ir.field(name)) + return self._with_unary(self._ir.field(name=name)) diff --git a/narwhals/_plan/temporal.py b/narwhals/_plan/temporal.py index f6a74587f7..bd21388728 100644 --- a/narwhals/_plan/temporal.py +++ b/narwhals/_plan/temporal.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal from narwhals._duration import Interval from narwhals._plan.common import ExprNamespace, Function, IRNamespace @@ -20,60 +20,26 @@ def _is_polars_time_unit(obj: Any) -> TypeIs[PolarsTimeUnit]: return obj in {"ns", "us", "ms"} +# fmt: off class TemporalFunction(Function, accessor="dt", options=FunctionOptions.elementwise): ... - - class Date(TemporalFunction): ... - - class Year(TemporalFunction): ... - - class Month(TemporalFunction): ... - - class Day(TemporalFunction): ... - - class Hour(TemporalFunction): ... - - class Minute(TemporalFunction): ... - - class Second(TemporalFunction): ... - - class Millisecond(TemporalFunction): ... - - class Microsecond(TemporalFunction): ... - - class Nanosecond(TemporalFunction): ... - - class OrdinalDay(TemporalFunction): ... - - class WeekDay(TemporalFunction): ... - - class TotalMinutes(TemporalFunction): ... - - class TotalSeconds(TemporalFunction): ... - - class TotalMilliseconds(TemporalFunction): ... - - class TotalMicroseconds(TemporalFunction): ... - - class TotalNanoseconds(TemporalFunction): ... - - +# fmt: on class ToString(TemporalFunction): __slots__ = ("format",) format: str @@ -94,7 +60,7 @@ class Timestamp(TemporalFunction): time_unit: PolarsTimeUnit @staticmethod - def from_time_unit(time_unit: TimeUnit, /) -> Timestamp: + def from_time_unit(time_unit: TimeUnit = "us", /) -> Timestamp: if not _is_polars_time_unit(time_unit): msg = f"invalid `time_unit` \n\nExpected one of ['ns', 'us', 'ms'], got {time_unit!r}." 
raise ValueError(msg) @@ -119,71 +85,28 @@ def from_interval(every: Interval, /) -> Truncate: class IRDateTimeNamespace(IRNamespace): - def date(self) -> Date: - return Date() - - def year(self) -> Year: - return Year() - - def month(self) -> Month: - return Month() - - def day(self) -> Day: - return Day() - - def hour(self) -> Hour: - return Hour() - - def minute(self) -> Minute: - return Minute() - - def second(self) -> Second: - return Second() - - def millisecond(self) -> Millisecond: - return Millisecond() - - def microsecond(self) -> Microsecond: - return Microsecond() - - def nanosecond(self) -> Nanosecond: - return Nanosecond() - - def ordinal_day(self) -> OrdinalDay: - return OrdinalDay() - - def weekday(self) -> WeekDay: - return WeekDay() - - def total_minutes(self) -> TotalMinutes: - return TotalMinutes() - - def total_seconds(self) -> TotalSeconds: - return TotalSeconds() - - def total_milliseconds(self) -> TotalMilliseconds: - return TotalMilliseconds() - - def total_microseconds(self) -> TotalMicroseconds: - return TotalMicroseconds() - - def total_nanoseconds(self) -> TotalNanoseconds: - return TotalNanoseconds() - - def to_string(self, format: str) -> ToString: - return ToString(format=format) - - def replace_time_zone(self, time_zone: str | None) -> ReplaceTimeZone: - return ReplaceTimeZone(time_zone=time_zone) - - def convert_time_zone(self, time_zone: str) -> ConvertTimeZone: - return ConvertTimeZone(time_zone=time_zone) - - def timestamp(self, time_unit: TimeUnit = "us") -> Timestamp: - return Timestamp.from_time_unit(time_unit) - - def truncate(self, every: str) -> Truncate: - return Truncate.from_string(every) + date: ClassVar = Date + year: ClassVar = Year + month: ClassVar = Month + day: ClassVar = Day + hour: ClassVar = Hour + minute: ClassVar = Minute + second: ClassVar = Second + millisecond: ClassVar = Millisecond + microsecond: ClassVar = Microsecond + nanosecond: ClassVar = Nanosecond + ordinal_day: ClassVar = OrdinalDay + weekday: ClassVar = WeekDay + total_minutes: ClassVar = TotalMinutes + total_seconds: ClassVar = TotalSeconds + total_milliseconds: ClassVar = TotalMilliseconds + total_microseconds: ClassVar = TotalMicroseconds + total_nanoseconds: ClassVar = TotalNanoseconds + to_string: ClassVar = ToString + replace_time_zone: ClassVar = ReplaceTimeZone + convert_time_zone: ClassVar = ConvertTimeZone + truncate: ClassVar = staticmethod(Truncate.from_string) + timestamp: ClassVar = staticmethod(Timestamp.from_time_unit) class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): @@ -252,7 +175,7 @@ def convert_time_zone(self, time_zone: str) -> Expr: return self._with_unary(self._ir.convert_time_zone(time_zone=time_zone)) def timestamp(self, time_unit: TimeUnit = "us") -> Expr: - return self._with_unary(self._ir.timestamp(time_unit=time_unit)) + return self._with_unary(self._ir.timestamp(time_unit)) def truncate(self, every: str) -> Expr: - return self._with_unary(self._ir.truncate(every=every)) + return self._with_unary(self._ir.truncate(every)) diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index 0b8b884b5f..251489c68d 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -47,10 +47,8 @@ RollingT = TypeVar("RollingT", bound="RollingWindow", default="RollingWindow") RangeT = TypeVar("RangeT", bound="RangeFunction", default="RangeFunction") LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR") -LeftT2 = TypeVar("LeftT2", bound="ExprIR", default="ExprIR") OperatorT = TypeVar("OperatorT", bound="ops.Operator", 
default="ops.Operator") RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR") -RightT2 = TypeVar("RightT2", bound="ExprIR", default="ExprIR") OperatorFn: TypeAlias = "t.Callable[[t.Any, t.Any], t.Any]" ExprIRT = TypeVar("ExprIRT", bound="ExprIR", default="ExprIR") ExprIRT2 = TypeVar("ExprIRT2", bound="ExprIR", default="ExprIR") @@ -96,3 +94,4 @@ IntoExprColumn: TypeAlias = "Expr | Series[t.Any] | str" IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn" +OneOrIterable: TypeAlias = "T | t.Iterable[T]" diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index d264f39733..62e0da3d2a 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import Immutable, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.dummy import Expr from narwhals._plan.expr_parsing import ( parse_into_expr_ir, @@ -10,11 +11,9 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable - from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Ternary - from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq + from narwhals._plan.expr import TernaryExpr + from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq class When(Immutable): @@ -39,7 +38,7 @@ class Then(Immutable, Expr): statement: ExprIR def when( - self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> ChainedWhen: condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( @@ -84,7 +83,7 @@ class ChainedThen(Immutable, Expr): statements: Seq[ExprIR] def when( - self, *predicates: IntoExprColumn | Iterable[IntoExprColumn], **constraints: Any + self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any ) -> ChainedWhen: condition = parse_predicates_constraints_into_expr_ir(*predicates, **constraints) return ChainedWhen( @@ -96,10 +95,8 @@ def otherwise(self, statement: IntoExpr, /) -> Expr: def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR: otherwise = parse_into_expr_ir(statement) - it_conditions = reversed(self.conditions) - it_statements = reversed(self.statements) - for e in it_conditions: - otherwise = ternary_expr(e, next(it_statements), otherwise) + for cond, stmt in zip(reversed(self.conditions), reversed(self.statements)): + otherwise = ternary_expr(cond, stmt, otherwise) return otherwise @property @@ -116,7 +113,7 @@ def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] return super().__eq__(value) -def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary: - from narwhals._plan.expr import Ternary +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: + from narwhals._plan.expr import TernaryExpr - return Ternary(predicate=predicate, truthy=truthy, falsy=falsy) + return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index f575d9d303..fd27743948 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr, is_window_expr +from narwhals._plan._guards import is_function_expr, is_window_expr +from narwhals._plan._immutable import Immutable from narwhals._plan.exceptions 
import ( over_elementwise_error, over_nested_error, diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 7d7f1f6248..dc548968a4 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -11,7 +11,7 @@ import narwhals as nw from narwhals._plan import demo as nwd, selectors as ndcs -from narwhals._plan.common import is_expr +from narwhals._plan._guards import is_expr from narwhals._plan.dummy import DataFrame from narwhals._utils import Version from narwhals.exceptions import ComputeError diff --git a/tests/plan/expr_rewrites_test.py b/tests/plan/expr_rewrites_test.py index 8e5dd0f29c..740d966818 100644 --- a/tests/plan/expr_rewrites_test.py +++ b/tests/plan/expr_rewrites_test.py @@ -6,7 +6,8 @@ import narwhals as nw from narwhals._plan import demo as nwd, expr_parsing as parse, selectors as ndcs -from narwhals._plan.common import ExprIR, NamedIR, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan.common import ExprIR, NamedIR from narwhals._plan.expr import WindowExpr from narwhals._plan.expr_rewrites import ( rewrite_all, diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 3c5e97439e..6f9d0450ad 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -6,7 +6,7 @@ import pytest -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable class Empty(Immutable): ... diff --git a/tests/plan/utils.py b/tests/plan/utils.py index 6b818f82df..4eaf98db9f 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprIR, NamedIR, is_expr +from narwhals._plan._guards import is_expr +from narwhals._plan.common import ExprIR, NamedIR if TYPE_CHECKING: from typing_extensions import LiteralString
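
Aside (illustrative, not part of the patch): expr_rewrites.py above swaps the bespoke `with_input`/`with_expr`/`with_right` helpers for the generic `replace(obj, **changes)` that pairs with `Immutable.__replace__`. A minimal sketch of the copy-with-changes idea, using frozen dataclasses instead of the `Immutable` base; the node names below are stand-ins, not the narwhals classes:

from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Column:
    name: str


@dataclass(frozen=True)
class WindowExpr:
    expr: Column
    partition_by: tuple[str, ...]


# Copy-with-changes: the original node is untouched; only the named field differs.
w = WindowExpr(expr=Column("a"), partition_by=("g",))
w2 = replace(w, expr=Column("b"))
assert w.expr.name == "a" and w2.expr.name == "b"
assert w2.partition_by == ("g",)

`replace(window, expr=parent)` in `rewrite_elementwise_over` is the same move: rebuild one field, keep the rest, and (per `Immutable.__replace__`) return the original object unchanged when the single replacement compares equal, preserving its cached hash.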
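
Aside (illustrative, not part of the patch): operators.py now configures each operator class through `__init_subclass__` keywords (`func=`, `symbol=`) rather than a per-class `_op` attribute plus a lookup table in `__repr__`. A standalone sketch of the pattern, assuming nothing beyond the standard library:

from __future__ import annotations

import operator as op
from typing import Any, Callable, ClassVar


class Operator:
    # Populated from the class-definition keywords by __init_subclass__ below.
    _func: ClassVar[Callable[[Any, Any], Any]]
    _symbol: ClassVar[str]

    def __init_subclass__(
        cls, *, func: Callable[[Any, Any], Any] | None = None, symbol: str = "", **kwds: Any
    ) -> None:
        super().__init_subclass__(**kwds)
        if func is not None:
            cls._func = func
        cls._symbol = symbol or cls.__name__

    def __call__(self, lhs: Any, rhs: Any) -> Any:
        # Apply the stored binary callable to the operands.
        return type(self)._func(lhs, rhs)

    def __repr__(self) -> str:
        return self._symbol


class Add(Operator, func=op.add, symbol="+"): ...
class Lt(Operator, func=op.lt, symbol="<"): ...


assert Add()(2, 3) == 5
assert repr(Lt()) == "<"

The symbol and callable now live on the class that uses them, so adding an operator is a single line and the removed `__repr__` dict cannot drift out of sync.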
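
Aside (illustrative, not part of the patch): lists.py, strings.py, struct.py, and temporal.py replace nullary factory methods on the IR namespaces with `ClassVar` attributes that point straight at the node class (for example `date: ClassVar = Date`), wrapping plain functions such as `Truncate.from_string` in `staticmethod(...)`. A sketch of why that works; `Date` here is a stand-in class, not the narwhals one:

from typing import ClassVar


class Date:
    # Stand-in for a zero-argument IR node.
    pass


class IRDateTimeNamespace:
    # A class object does not implement the descriptor protocol the way a function does,
    # so attribute access returns the class itself and calling it constructs the node.
    date: ClassVar = Date
    # A plain function assigned here *would* bind as a method and receive the namespace
    # as its first argument, which is why the patch wraps Truncate.from_string in staticmethod().


ns = IRDateTimeNamespace()
assert isinstance(ns.date(), Date)

Nodes that take arguments go through the same attribute, which is why the call sites switch to keyword arguments (`self._ir.field(name=name)`, `self._ir.starts_with(prefix=prefix)`).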
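
Aside (illustrative, not part of the patch): the recurring `X | Iterable[X]` unions are collapsed into the `OneOrIterable[X]` alias added to typing.py. A sketch of how such a parameter is typically normalized at runtime, mirroring `SortMultipleOptions.parse` above; the `_to_tuple` helper is hypothetical, not something this diff adds:

from __future__ import annotations

from collections.abc import Iterable
from typing import TypeVar, Union

T = TypeVar("T")
OneOrIterable = Union[T, Iterable[T]]


def _to_tuple(value: OneOrIterable[bool]) -> tuple[bool, ...]:
    # A lone scalar becomes a 1-tuple; any iterable is materialized as given.
    return (value,) if isinstance(value, bool) else tuple(value)


assert _to_tuple(True) == (True,)
assert _to_tuple([True, False]) == (True, False)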