diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py new file mode 100644 index 0000000000..248e1b5631 --- /dev/null +++ b/narwhals/_plan/_immutable.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypeVar + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Any, Callable + + from typing_extensions import Never, Self, dataclass_transform + +else: + # https://docs.python.org/3/library/typing.html#typing.dataclass_transform + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + frozen_default: bool = False, + field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), + **kwargs: Any, + ) -> Callable[[T], T]: + def decorator(cls_or_fn: T) -> T: + cls_or_fn.__dataclass_transform__ = { + "eq_default": eq_default, + "order_default": order_default, + "kw_only_default": kw_only_default, + "frozen_default": frozen_default, + "field_specifiers": field_specifiers, + "kwargs": kwargs, + } + return cls_or_fn + + return decorator + + +T = TypeVar("T") +_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" + + +@dataclass_transform(kw_only_default=True, frozen_default=True) +class Immutable: + """A poor man's frozen dataclass. + + - Keyword-only constructor (IDE supported) + - Manual `__slots__` required + - Compatible with [`copy.replace`] + - No handling for default arguments + + [`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace + """ + + __slots__ = (_IMMUTABLE_HASH_NAME,) + __immutable_hash_value__: int + + @property + def __immutable_keys__(self) -> Iterator[str]: + slots: tuple[str, ...] = self.__slots__ + for name in slots: + if name != _IMMUTABLE_HASH_NAME: + yield name + + @property + def __immutable_values__(self) -> Iterator[Any]: + for name in self.__immutable_keys__: + yield getattr(self, name) + + @property + def __immutable_items__(self) -> Iterator[tuple[str, Any]]: + for name in self.__immutable_keys__: + yield name, getattr(self, name) + + @property + def __immutable_hash__(self) -> int: + if hasattr(self, _IMMUTABLE_HASH_NAME): + return self.__immutable_hash_value__ + hash_value = hash((self.__class__, *self.__immutable_values__)) + object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) + return self.__immutable_hash_value__ + + def __setattr__(self, name: str, value: Never) -> Never: + msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." + raise AttributeError(msg) + + def __replace__(self, **changes: Any) -> Self: + """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 + if len(changes) == 1: + k_new, v_new = next(iter(changes.items())) + # NOTE: Will trigger an attribute error if invalid name + if getattr(self, k_new) == v_new: + return self + changed = dict(self.__immutable_items__) + # Now we *don't* need to check the key is valid + changed[k_new] = v_new + else: + changed = dict(self.__immutable_items__) + changed |= changes + return type(self)(**changed) + + def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: + super().__init_subclass__(*args, **kwds) + if cls.__slots__: + ... + else: + cls.__slots__ = () + + def __hash__(self) -> int: + return self.__immutable_hash__ + + def __eq__(self, other: object) -> bool: + if self is other: + return True + if type(self) is not type(other): + return False + return all( + getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ + ) + + def __str__(self) -> str: + fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) + return f"{type(self).__name__}({fields})" + + def __init__(self, **kwds: Any) -> None: + required: set[str] = set(self.__immutable_keys__) + if not required and not kwds: + # NOTE: Fastpath for empty slots + ... + elif required == set(kwds): + for name, value in kwds.items(): + object.__setattr__(self, name, value) + elif missing := required.difference(kwds): + msg = ( + f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" + f"but missing values for {sorted(missing)!r}" + ) + raise TypeError(msg) + else: + extra = set(kwds).difference(required) + msg = ( + f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" + f"but got unknown arguments {sorted(extra)!r}" + ) + raise TypeError(msg) + + +def _field_str(name: str, value: Any) -> str: + if isinstance(value, tuple): + inner = ", ".join(f"{v}" for v in value) + return f"{name}=[{inner}]" + if isinstance(value, str): + return f"{name}={value!r}" + return f"{name}={value}" diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py index cbc5f600b4..e6d1389a29 100644 --- a/narwhals/_plan/arrow/dataframe.py +++ b/narwhals/_plan/arrow/dataframe.py @@ -9,7 +9,7 @@ from narwhals._plan.arrow import functions as fn from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.common import ExprIR -from narwhals._plan.protocols import EagerDataFrame +from narwhals._plan.protocols import EagerDataFrame, namespace from narwhals._utils import Version if t.TYPE_CHECKING: @@ -89,7 +89,7 @@ def to_dict( return {ser.name: ser.to_list() for ser in it} def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]: - ns = self.__narwhals_namespace__() + ns = namespace(self) from_named_ir = ns._expr.from_named_ir yield from ns._expr.align(from_named_ir(e, self) for e in nodes) diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py index ef6e0164ed..fbf758c51d 100644 --- a/narwhals/_plan/arrow/expr.py +++ b/narwhals/_plan/arrow/expr.py @@ -11,7 +11,7 @@ from narwhals._plan.arrow.series import ArrowSeries from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co from narwhals._plan.common import ExprIR, NamedIR, into_dtype -from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch +from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace from narwhals._utils import ( Implementation, Version, @@ -53,7 +53,7 @@ FunctionExpr, OrderedWindowExpr, RollingExpr, - Ternary, + TernaryExpr, WindowExpr, ) from narwhals._plan.functions import FillNull, Pow @@ -76,32 +76,32 @@ def __narwhals_namespace__(self) -> ArrowNamespace: def _with_native(self, native: Any, name: str, /) -> StoresNativeT_co: ... def cast(self, node: expr.Cast, frame: ArrowDataFrame, name: str) -> StoresNativeT_co: data_type = narwhals_to_native_dtype(node.dtype, frame.version) - native = self._dispatch(node.expr, frame, name).native + native = node.expr.dispatch(self, frame, name).native return self._with_native(fn.cast(native, data_type), name) def pow( self, node: FunctionExpr[Pow], frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: base, exponent = node.function.unwrap_input(node) - base_ = self._dispatch(base, frame, "base").native - exponent_ = self._dispatch(exponent, frame, "exponent").native + base_ = base.dispatch(self, frame, "base").native + exponent_ = exponent.dispatch(self, frame, "exponent").native return self._with_native(pc.power(base_, exponent_), name) def fill_null( self, node: FunctionExpr[FillNull], frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: expr, value = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - value_ = self._dispatch(value, frame, "value").native + native = expr.dispatch(self, frame, name).native + value_ = value.dispatch(self, frame, "value").native return self._with_native(pc.fill_null(native, value_), name) def is_between( self, node: FunctionExpr[IsBetween], frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: expr, lower_bound, upper_bound = node.function.unwrap_input(node) - native = self._dispatch(expr, frame, name).native - lower = self._dispatch(lower_bound, frame, "lower").native - upper = self._dispatch(upper_bound, frame, "upper").native + native = expr.dispatch(self, frame, name).native + lower = lower_bound.dispatch(self, frame, "lower").native + upper = upper_bound.dispatch(self, frame, "upper").native result = fn.is_between(native, lower, upper, node.function.closed) return self._with_native(result, name) @@ -111,7 +111,7 @@ def _unary_function( def func( node: FunctionExpr[Any], frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: - native = self._dispatch(node.input[0], frame, name).native + native = node.input[0].dispatch(self, frame, name).native return self._with_native(fn_native(native), name) return func @@ -150,18 +150,18 @@ def binary_expr( self, node: BinaryExpr, frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: lhs, rhs = ( - self._dispatch(node.left, frame, name), - self._dispatch(node.right, frame, name), + node.left.dispatch(self, frame, name), + node.right.dispatch(self, frame, name), ) result = fn.binary(lhs.native, node.op.__class__, rhs.native) return self._with_native(result, name) def ternary_expr( - self, node: Ternary, frame: ArrowDataFrame, name: str + self, node: TernaryExpr, frame: ArrowDataFrame, name: str ) -> StoresNativeT_co: - when = self._dispatch(node.predicate, frame, name) - then = self._dispatch(node.truthy, frame, name) - otherwise = self._dispatch(node.falsy, frame, name) + when = node.predicate.dispatch(self, frame, name) + then = node.truthy.dispatch(self, frame, name) + otherwise = node.falsy.dispatch(self, frame, name) result = pc.if_else(when.native, then.native, otherwise.native) return self._with_native(result, name) @@ -216,7 +216,7 @@ def _dispatch_expr( Mainly for the benefit of a type checker, but the equivalent `ArrowScalar._dispatch_expr` will raise if the assumption fails. """ - return self._dispatch(node, frame, name).to_series() + return node.dispatch(self, frame, name).to_series() @property def native(self) -> ChunkedArrayAny: @@ -245,8 +245,7 @@ def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowE self._dispatch_expr(e, frame, f"_{idx}") for idx, e in enumerate(node.by) ) - ns = self.__narwhals_namespace__() - df = ns._concat_horizontal((series, *by)) + df = namespace(self)._concat_horizontal((series, *by)) names = df.columns[1:] indices = pc.sort_indices(df.native, options=node.options.to_arrow(names)) result: ChunkedArrayAny = df.native.column(0).take(indices) @@ -351,7 +350,7 @@ def over_ordered( options = node.sort_options.to_multiple(len(node.order_by)) idx_name = generate_temporary_column_name(8, frame.columns) sorted_context = frame.with_row_index(idx_name).sort(sort_by, options) - evaluated = self._dispatch(node.expr, sorted_context.drop([idx_name]), name) + evaluated = node.expr.dispatch(self, sorted_context.drop([idx_name]), name) if isinstance(evaluated, ArrowScalar): # NOTE: We're already sorted, defer broadcasting to the outer context # Wouldn't be suitable for partitions, but will be fine here @@ -479,7 +478,7 @@ def var(self, node: Var, frame: ArrowDataFrame, name: str) -> ArrowScalar: return self._with_native(pa.scalar(None, pa.null()), name) def count(self, node: Count, frame: ArrowDataFrame, name: str) -> ArrowScalar: - native = self._dispatch(node.expr, frame, name).native + native = node.expr.dispatch(self, frame, name).native return self._with_native(pa.scalar(1 if native.is_valid else 0), name) filter = not_implemented() diff --git a/narwhals/_plan/arrow/namespace.py b/narwhals/_plan/arrow/namespace.py index e941727d6b..51730a0d7c 100644 --- a/narwhals/_plan/arrow/namespace.py +++ b/narwhals/_plan/arrow/namespace.py @@ -27,7 +27,7 @@ from narwhals._plan.dummy import Series from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.ranges import IntRange - from narwhals._plan.strings import ConcatHorizontal + from narwhals._plan.strings import ConcatStr from narwhals.typing import ConcatMethod, NonNestedLiteral, PythonLiteral @@ -162,7 +162,7 @@ def mean_horizontal( return self._expr.from_native(result, name, self.version) def concat_str( - self, node: FunctionExpr[ConcatHorizontal], frame: ArrowDataFrame, name: str + self, node: FunctionExpr[ConcatStr], frame: ArrowDataFrame, name: str ) -> ArrowExpr | ArrowScalar: exprs = (self._expr.from_ir(e, frame, name) for e in node.input) aligned = (ser.native for ser in self._expr.align(exprs)) diff --git a/narwhals/_plan/boolean.py b/narwhals/_plan/boolean.py index 7a3902cb6b..6b469a53f0 100644 --- a/narwhals/_plan/boolean.py +++ b/narwhals/_plan/boolean.py @@ -4,8 +4,8 @@ # - Any import typing as t -from narwhals._plan.common import Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.common import Function, HorizontalFunction +from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar if t.TYPE_CHECKING: @@ -21,22 +21,22 @@ ExprT = TypeVar("ExprT", bound="ExprIR", default="ExprIR") -class BooleanFunction(Function): ... +class BooleanFunction(Function, options=FunctionOptions.elementwise): ... class All(BooleanFunction, options=FunctionOptions.aggregation): ... -class AllHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... +class AllHorizontal(HorizontalFunction, BooleanFunction): ... class Any(BooleanFunction, options=FunctionOptions.aggregation): ... -class AnyHorizontal(BooleanFunction, options=FunctionOptions.horizontal): ... +class AnyHorizontal(HorizontalFunction, BooleanFunction): ... -class IsBetween(BooleanFunction, options=FunctionOptions.elementwise): +class IsBetween(BooleanFunction): """N-ary (expr, lower_bound, upper_bound).""" __slots__ = ("closed",) @@ -50,13 +50,13 @@ def unwrap_input(self, node: FunctionExpr[Self], /) -> tuple[ExprIR, ExprIR, Exp class IsDuplicated(BooleanFunction, options=FunctionOptions.length_preserving): ... -class IsFinite(BooleanFunction, options=FunctionOptions.elementwise): ... +class IsFinite(BooleanFunction): ... class IsFirstDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... -class IsIn(BooleanFunction, t.Generic[OtherT], options=FunctionOptions.elementwise): +class IsIn(BooleanFunction, t.Generic[OtherT]): __slots__ = ("other",) other: OtherT @@ -95,14 +95,13 @@ def __init__(self, *, other: ExprT) -> None: class IsLastDistinct(BooleanFunction, options=FunctionOptions.length_preserving): ... -class IsNan(BooleanFunction, options=FunctionOptions.elementwise): ... +class IsNan(BooleanFunction): ... -class IsNull(BooleanFunction, options=FunctionOptions.elementwise): ... +class IsNull(BooleanFunction): ... class IsUnique(BooleanFunction, options=FunctionOptions.length_preserving): ... -class Not(BooleanFunction, options=FunctionOptions.elementwise): - """`__invert__`.""" +class Not(BooleanFunction, config=FEOptions.renamed("not_")): ... diff --git a/narwhals/_plan/categorical.py b/narwhals/_plan/categorical.py index 7fb58367f9..e698d381fe 100644 --- a/narwhals/_plan/categorical.py +++ b/narwhals/_plan/categorical.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr @@ -12,7 +11,7 @@ class CategoricalFunction(Function, accessor="cat"): ... -class GetCategories(CategoricalFunction, options=FunctionOptions.groupwise): ... +class GetCategories(CategoricalFunction): ... class IRCatNamespace(IRNamespace): diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 95783bcf45..6b6382496d 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -5,13 +5,17 @@ import sys from collections.abc import Iterable from decimal import Decimal +from operator import attrgetter from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload +from narwhals._plan._immutable import Immutable +from narwhals._plan.options import ExprIROptions, FEOptions, FunctionOptions from narwhals._plan.typing import ( Accessor, DTypeT, ExprIRT, ExprIRT2, + FunctionT, IRNamespaceT, MapIR, NamedOrExprIRT, @@ -25,9 +29,9 @@ if TYPE_CHECKING: from collections.abc import Iterator - from typing import Any, Callable, Literal + from typing import Any, Callable - from typing_extensions import Never, Self, TypeIs, dataclass_transform + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._plan import expr from narwhals._plan.dummy import Expr, Selector, Series @@ -41,37 +45,9 @@ WindowExpr, ) from narwhals._plan.meta import IRMetaNamespace - from narwhals._plan.options import FunctionOptions - from narwhals._plan.protocols import CompliantSeries + from narwhals._plan.protocols import CompliantSeries, Ctx, FrameT_contra, R_co from narwhals.typing import NonNestedDType, NonNestedLiteral -else: - # NOTE: This isn't important to the proposal, just wanted IDE support - # for the **temporary** constructors. - # It is interesting how much boilerplate this avoids though 🤔 - # https://docs.python.org/3/library/typing.html#typing.dataclass_transform - def dataclass_transform( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - frozen_default: bool = False, - field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (), - **kwargs: Any, - ) -> Callable[[T], T]: - def decorator(cls_or_fn: T) -> T: - cls_or_fn.__dataclass_transform__ = { - "eq_default": eq_default, - "order_default": order_default, - "kw_only_default": kw_only_default, - "frozen_default": frozen_default, - "field_specifiers": field_specifiers, - "kwargs": kwargs, - } - return cls_or_fn - - return decorator - if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 @@ -87,116 +63,71 @@ def replace(obj: T, /, **changes: Any) -> T: T = TypeVar("T") +Incomplete: TypeAlias = "Any" -_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__" +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase, camelCase string to snake_case. -@dataclass_transform(kw_only_default=True, frozen_default=True) -class Immutable: - __slots__ = (_IMMUTABLE_HASH_NAME,) - __immutable_hash_value__: int + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - @property - def __immutable_keys__(self) -> Iterator[str]: - slots: tuple[str, ...] = self.__slots__ - for name in slots: - if name != _IMMUTABLE_HASH_NAME: - yield name - @property - def __immutable_values__(self) -> Iterator[Any]: - for name in self.__immutable_keys__: - yield getattr(self, name) +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - @property - def __immutable_items__(self) -> Iterator[tuple[str, Any]]: - for name in self.__immutable_keys__: - yield name, getattr(self, name) - @property - def __immutable_hash__(self) -> int: - if hasattr(self, _IMMUTABLE_HASH_NAME): - return self.__immutable_hash_value__ - hash_value = hash((self.__class__, *self.__immutable_values__)) - object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value) - return self.__immutable_hash_value__ - - def __setattr__(self, name: str, value: Never) -> Never: - msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set." - raise AttributeError(msg) - - def __replace__(self, **changes: Any) -> Self: - """https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415 - if len(changes) == 1: - k_new, v_new = next(iter(changes.items())) - # NOTE: Will trigger an attribute error if invalid name - if getattr(self, k_new) == v_new: - return self - changed = dict(self.__immutable_items__) - # Now we *don't* need to check the key is valid - changed[k_new] = v_new - else: - changed = dict(self.__immutable_items__) - changed |= changes - return type(self)(**changed) +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" - def __init_subclass__(cls, *args: Any, **kwds: Any) -> None: - super().__init_subclass__(*args, **kwds) - if cls.__slots__: - ... - else: - cls.__slots__ = () - - def __hash__(self) -> int: - return self.__immutable_hash__ - - def __eq__(self, other: object) -> bool: - if self is other: - return True - if type(self) is not type(other): - return False - return all( - getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__ - ) - - def __str__(self) -> str: - # NOTE: Debug repr, closer to constructor - fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__) - return f"{type(self).__name__}({fields})" - - def __init__(self, **kwds: Any) -> None: - # NOTE: DUMMY CONSTRUCTOR - don't use beyond prototyping! - # Just need a quick way to demonstrate `ExprIR` and interactions - required: set[str] = set(self.__immutable_keys__) - if not required and not kwds: - # NOTE: Fastpath for empty slots - ... - elif required == set(kwds): - # NOTE: Everything is as expected - for name, value in kwds.items(): - object.__setattr__(self, name, value) - elif missing := required.difference(kwds): - msg = ( - f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n" - f"but missing values for {sorted(missing)!r}" - ) - raise TypeError(msg) - else: - extra = set(kwds).difference(required) + +def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or _pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def _dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: + getter = attrgetter(_dispatch_method_name(tp)) + if tp.__expr_ir_config__.origin == "expr": + return getter + return lambda ctx: getter(ctx.__narwhals_namespace__()) + + +def _dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + if not tp.__expr_ir_config__.allow_dispatch: + + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: msg = ( - f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n" - f"but got unknown arguments {sorted(extra)!r}" + f"{tp.__name__!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" ) raise TypeError(msg) + return _ + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ -def _field_str(name: str, value: Any) -> str: - if isinstance(value, tuple): - inner = ", ".join(f"{v}" for v in value) - return f"{name}=[{inner}]" - if isinstance(value, str): - return f"{name}={value!r}" - return f"{name}={value}" + +def _dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = _dispatch_getter(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ class ExprIR(Immutable): @@ -205,10 +136,30 @@ class ExprIR(Immutable): _child: ClassVar[Seq[str]] = () """Nested node names, in iteration order.""" - def __init_subclass__(cls, *args: Any, child: Seq[str] = (), **kwds: Any) -> None: + __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] + ] + + def __init_subclass__( + cls: type[Self], + *args: Any, + child: Seq[str] = (), + config: ExprIROptions | None = None, + **kwds: Any, + ) -> None: super().__init_subclass__(*args, **kwds) if child: cls._child = child + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + ) -> R_co: + """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import dummy @@ -334,7 +285,7 @@ def _repr_html_(self) -> str: return self.__repr__() -class SelectorIR(ExprIR): +class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): def to_narwhals(self, version: Version = Version.MAIN) -> Selector: from narwhals._plan import dummy @@ -418,7 +369,7 @@ def is_elementwise_top_level(self) -> bool: return ir.options.is_elementwise() if is_literal(ir): return ir.is_scalar - return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.Ternary, expr.Cast)) + return isinstance(ir, (expr.BinaryExpr, expr.Column, expr.TernaryExpr, expr.Cast)) class IRNamespace(Immutable): @@ -449,26 +400,19 @@ def _with_unary(self, function: Function, /) -> Expr: return self._expr._with_unary(function) -def _function_options_default() -> FunctionOptions: - from narwhals._plan.options import FunctionOptions - - return FunctionOptions.default() - - class Function(Immutable): """Shared by expr functions and namespace functions. - Only valid in `FunctionExpr.function` - https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 """ - _accessor: ClassVar[Accessor | None] = None - """Namespace accessor name, if any.""" - _function_options: ClassVar[staticmethod[[], FunctionOptions]] = staticmethod( - _function_options_default + FunctionOptions.default ) + __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() + __expr_ir_dispatch__: ClassVar[ + staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] + ] @property def function_options(self) -> FunctionOptions: @@ -484,45 +428,29 @@ def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: return FunctionExpr(input=inputs, function=self, options=self.function_options) def __init_subclass__( - cls, + cls: type[Self], *args: Any, accessor: Accessor | None = None, options: Callable[[], FunctionOptions] | None = None, + config: FEOptions | None = None, **kwds: Any, ) -> None: super().__init_subclass__(*args, **kwds) if accessor: - cls._accessor = accessor + config = replace(config or FEOptions.default(), accessor_name=accessor) if options: cls._function_options = staticmethod(options) + if config: + cls.__expr_ir_config__ = config + cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) def __repr__(self) -> str: - return _function_repr(type(self)) - + return _dispatch_method_name(type(self)) -# TODO @dangotbanned: Add caching strategy? -def _function_repr(tp: type[Function], /) -> str: - name = _pascal_to_snake_case(tp.__name__) - return f"{ns_name}.{name}" if (ns_name := tp._accessor) else name - -def _pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. - - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" +class HorizontalFunction( + Function, options=FunctionOptions.horizontal, config=FEOptions.namespaced() +): ... _NON_NESTED_LITERAL_TPS = ( diff --git a/narwhals/_plan/demo.py b/narwhals/_plan/demo.py index 85b8ac2ad7..745309878d 100644 --- a/narwhals/_plan/demo.py +++ b/narwhals/_plan/demo.py @@ -13,7 +13,7 @@ from narwhals._plan.expr import All, Len from narwhals._plan.literal import ScalarLiteral, SeriesLiteral from narwhals._plan.ranges import IntRange -from narwhals._plan.strings import ConcatHorizontal +from narwhals._plan.strings import ConcatStr from narwhals._plan.when_then import When from narwhals._utils import Version, flatten @@ -121,7 +121,7 @@ def concat_str( ) -> Expr: it = parse.parse_into_seq_of_expr_ir(exprs, *more_exprs) return ( - ConcatHorizontal(separator=separator, ignore_nulls=ignore_nulls) + ConcatStr(separator=separator, ignore_nulls=ignore_nulls) .to_function_expr(*it) .to_narwhals() ) diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 704d3db612..ef120c2830 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -11,6 +11,7 @@ from narwhals._plan.common import ExprIR, SelectorIR, collect from narwhals._plan.exceptions import function_expr_invalid_operation_error from narwhals._plan.name import KeepName, RenameAlias +from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ( FunctionT, LeftSelectorT, @@ -37,6 +38,7 @@ from narwhals._plan.functions import MapBatches # noqa: F401 from narwhals._plan.literal import LiteralValue from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions + from narwhals._plan.protocols import Ctx, FrameT_contra, R_co from narwhals._plan.selectors import Selector from narwhals._plan.window import Window from narwhals.dtypes import DType @@ -66,7 +68,7 @@ "SelectorIR", "Sort", "SortBy", - "Ternary", + "TernaryExpr", "WindowExpr", "col", ] @@ -88,7 +90,7 @@ def index_columns(*indices: int) -> IndexColumns: return IndexColumns(indices=indices) -class Alias(ExprIR, child=("expr",)): +class Alias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "name") expr: ExprIR name: str @@ -107,7 +109,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: return common.replace(self, expr=expr) -class Column(ExprIR): +class Column(ExprIR, config=ExprIROptions.namespaced("col")): __slots__ = ("name",) name: str @@ -118,7 +120,7 @@ def with_name(self, name: str, /) -> Column: return common.replace(self, name=name) -class _ColumnSelection(ExprIR): +class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()): """Nodes which can resolve to `Column`(s) with a `Schema`.""" @@ -173,7 +175,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: return common.replace(self, expr=expr) -class Literal(ExprIR, t.Generic[LiteralT]): +class Literal(ExprIR, t.Generic[LiteralT], config=ExprIROptions.namespaced("lit")): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L81.""" __slots__ = ("value",) @@ -376,13 +378,25 @@ def __init__( raise function_expr_invalid_operation_error(function, parent) super().__init__(**dict(input=input, function=function, options=options, **kwds)) + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RollingExpr(FunctionExpr[RollingT]): ... -class AnonymousExpr(FunctionExpr["MapBatches"]): +class AnonymousExpr( + FunctionExpr["MapBatches"], config=ExprIROptions.renamed("map_batches") +): """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" + def dispatch( + self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + ) -> R_co: + return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + class RangeExpr(FunctionExpr[RangeT]): """E.g. `int_range(...)`. @@ -434,7 +448,9 @@ def map_ir(self, function: MapIR, /) -> ExprIR: return function(changed) -class WindowExpr(ExprIR, child=("expr", "partition_by")): +class WindowExpr( + ExprIR, child=("expr", "partition_by"), config=ExprIROptions.renamed("over") +): """A fully specified `.over()`, that occurred after another expression. Related: @@ -468,7 +484,11 @@ def with_partition_by(self, partition_by: t.Iterable[ExprIR], /) -> Self: return common.replace(self, partition_by=collect(partition_by)) -class OrderedWindowExpr(WindowExpr, child=("expr", "partition_by", "order_by")): +class OrderedWindowExpr( + WindowExpr, + child=("expr", "partition_by", "order_by"), + config=ExprIROptions.renamed("over_ordered"), +): __slots__ = ("expr", "partition_by", "order_by", "sort_options", "options") # noqa: RUF023 expr: ExprIR partition_by: Seq[ExprIR] @@ -505,7 +525,7 @@ def with_order_by(self, order_by: t.Iterable[ExprIR], /) -> Self: return common.replace(self, order_by=collect(order_by)) -class Len(ExprIR): +class Len(ExprIR, config=ExprIROptions.namespaced()): @property def is_scalar(self) -> bool: return True @@ -555,7 +575,7 @@ def matches_column(self, name: str, dtype: DType) -> bool: return not self.selector.matches_column(name, dtype) -class Ternary(ExprIR, child=("truthy", "falsy", "predicate")): +class TernaryExpr(ExprIR, child=("truthy", "falsy", "predicate")): """When-Then-Otherwise.""" __slots__ = ("truthy", "falsy", "predicate") # noqa: RUF023 diff --git a/narwhals/_plan/expr_expansion.py b/narwhals/_plan/expr_expansion.py index d7ab345f81..86073ef074 100644 --- a/narwhals/_plan/expr_expansion.py +++ b/narwhals/_plan/expr_expansion.py @@ -44,13 +44,8 @@ from typing import TYPE_CHECKING from narwhals._plan import common -from narwhals._plan.common import ( - ExprIR, - Immutable, - NamedIR, - SelectorIR, - is_horizontal_reduction, -) +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import ExprIR, NamedIR, SelectorIR, is_horizontal_reduction from narwhals._plan.exceptions import ( column_index_error, column_not_found_error, diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 4c80849a89..b6df6d94ca 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Function +from narwhals._plan.common import Function, HorizontalFunction from narwhals._plan.exceptions import hist_bins_monotonic_error from narwhals._plan.options import FunctionFlags, FunctionOptions @@ -24,7 +24,7 @@ class Abs(Function, options=FunctionOptions.elementwise): ... -class Hist(Function, options=FunctionOptions.groupwise): +class Hist(Function): """Only supported for `Series` so far.""" __slots__ = ("include_breakpoint",) @@ -105,13 +105,13 @@ class Shift(Function, options=FunctionOptions.length_preserving): class DropNulls(Function, options=FunctionOptions.row_separable): ... -class Mode(Function, options=FunctionOptions.groupwise): ... +class Mode(Function): ... class Skew(Function, options=FunctionOptions.aggregation): ... -class Rank(Function, options=FunctionOptions.groupwise): +class Rank(Function): __slots__ = ("options",) options: RankOptions @@ -165,7 +165,7 @@ class RollingStd(RollingWindow): ... class Diff(Function, options=FunctionOptions.length_preserving): ... -class Unique(Function, options=FunctionOptions.groupwise): ... +class Unique(Function): ... class Round(Function, options=FunctionOptions.elementwise): @@ -173,16 +173,16 @@ class Round(Function, options=FunctionOptions.elementwise): decimals: int -class SumHorizontal(Function, options=FunctionOptions.horizontal): ... +class SumHorizontal(HorizontalFunction): ... -class MinHorizontal(Function, options=FunctionOptions.horizontal): ... +class MinHorizontal(HorizontalFunction): ... -class MaxHorizontal(Function, options=FunctionOptions.horizontal): ... +class MaxHorizontal(HorizontalFunction): ... -class MeanHorizontal(Function, options=FunctionOptions.horizontal): ... +class MeanHorizontal(HorizontalFunction): ... class EwmMean(Function, options=FunctionOptions.length_preserving): @@ -197,7 +197,7 @@ class ReplaceStrict(Function, options=FunctionOptions.elementwise): return_dtype: DType | None -class GatherEvery(Function, options=FunctionOptions.groupwise): +class GatherEvery(Function): __slots__ = ("n", "offset") n: int offset: int diff --git a/narwhals/_plan/literal.py b/narwhals/_plan/literal.py index 1dc443fcef..6b51a87f25 100644 --- a/narwhals/_plan/literal.py +++ b/narwhals/_plan/literal.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any, Generic -from narwhals._plan.common import Immutable, is_literal +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import is_literal from narwhals._plan.typing import LiteralT, NativeSeriesT, NonNestedLiteralT if TYPE_CHECKING: diff --git a/narwhals/_plan/name.py b/narwhals/_plan/name.py index 88b8f5ec35..b722a2c1da 100644 --- a/narwhals/_plan/name.py +++ b/narwhals/_plan/name.py @@ -3,7 +3,9 @@ from typing import TYPE_CHECKING from narwhals._plan import common -from narwhals._plan.common import ExprIR, ExprNamespace, Immutable, IRNamespace +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import ExprIR +from narwhals._plan.options import ExprIROptions if TYPE_CHECKING: from typing_extensions import Self @@ -13,7 +15,7 @@ from narwhals._plan.typing import MapIR -class KeepName(ExprIR, child=("expr",)): +class KeepName(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr",) expr: ExprIR @@ -31,7 +33,7 @@ def with_expr(self, expr: ExprIR, /) -> Self: return common.replace(self, expr=expr) -class RenameAlias(ExprIR, child=("expr",)): +class RenameAlias(ExprIR, child=("expr",), config=ExprIROptions.no_dispatch()): __slots__ = ("expr", "function") expr: ExprIR function: AliasName @@ -66,7 +68,7 @@ def __call__(self, name: str, /) -> str: return f"{name}{self.suffix}" -class IRNameNamespace(IRNamespace): +class IRNameNamespace(common.IRNamespace): def keep(self) -> KeepName: return KeepName(expr=self._ir) @@ -86,7 +88,7 @@ def to_uppercase(self) -> RenameAlias: return self.map(str.upper) -class ExprNameNamespace(ExprNamespace[IRNameNamespace]): +class ExprNameNamespace(common.ExprNamespace[IRNameNamespace]): @property def _ir_namespace(self) -> type[IRNameNamespace]: return IRNameNamespace diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 38c912c1bd..cc5d89befd 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -3,7 +3,8 @@ import operator as op from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import is_function_expr from narwhals._plan.exceptions import ( binary_expr_length_changing_error, binary_expr_multi_output_error, diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index ca6cf91a04..f4eb9cd2df 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -4,16 +4,19 @@ from itertools import repeat from typing import TYPE_CHECKING, Literal -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable if TYPE_CHECKING: from collections.abc import Iterable, Sequence import pyarrow.compute as pc + from typing_extensions import Self, TypeAlias - from narwhals._plan.typing import Seq + from narwhals._plan.typing import Accessor, Seq from narwhals.typing import RankMethod +DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] + class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -263,3 +266,56 @@ def rolling_options( center=center, fn_params=ddof if ddof is None else RollingVarParams(ddof=ddof), ) + + +class _BaseIROptions(Immutable): + __slots__ = ("origin", "override_name") + origin: DispatchOrigin + override_name: str + + def __repr__(self) -> str: + return self.__str__() + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="") + + @classmethod + def renamed(cls, name: str, /) -> Self: + from narwhals._plan.common import replace + + return replace(cls.default(), override_name=name) + + @classmethod + def namespaced(cls, override_name: str = "", /) -> Self: + from narwhals._plan.common import replace + + return replace( + cls.default(), origin="__narwhals_namespace__", override_name=override_name + ) + + +class ExprIROptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "allow_dispatch") + allow_dispatch: bool + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="", allow_dispatch=True) + + @staticmethod + def no_dispatch() -> ExprIROptions: + return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + + +class FunctionExprOptions(_BaseIROptions): + __slots__ = (*_BaseIROptions.__slots__, "accessor_name") + accessor_name: Accessor | None + """Namespace accessor name, if any.""" + + @classmethod + def default(cls) -> Self: + return cls(origin="expr", override_name="", accessor_name=None) + + +FEOptions = FunctionExprOptions diff --git a/narwhals/_plan/protocols.py b/narwhals/_plan/protocols.py index 821e7a338a..f49635d46a 100644 --- a/narwhals/_plan/protocols.py +++ b/narwhals/_plan/protocols.py @@ -1,10 +1,9 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence, Sized -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from typing import TYPE_CHECKING, Any, Literal, Protocol, overload -from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings -from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe +from narwhals._plan.common import ExprIR, NamedIR, flatten_hash_safe from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq from narwhals._typing_compat import TypeVar from narwhals._utils import Version, _hasattr_static @@ -12,10 +11,12 @@ if TYPE_CHECKING: from typing_extensions import Self, TypeAlias, TypeIs + from narwhals._plan import aggregation as agg, boolean, expr, functions as F from narwhals._plan.dummy import BaseFrame, DataFrame, Series from narwhals._plan.expr import FunctionExpr, RangeExpr from narwhals._plan.options import SortMultipleOptions from narwhals._plan.ranges import IntRange + from narwhals._plan.strings import ConcatStr from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.typing import ( @@ -69,6 +70,22 @@ LazyExprT_co = TypeVar("LazyExprT_co", bound=LazyExprAny, covariant=True) LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True) +Ctx: TypeAlias = "ExprDispatch[FrameT_contra, R_co, NamespaceAny]" +"""Type of an unknown expression dispatch context. + +- `FrameT_contra`: Compliant data/lazyframe +- `R_co`: Upper bound return type of the context +""" + + +class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]): + def __narwhals_namespace__(self) -> NamespaceT_co: ... + + +def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co: + """Return the compliant namespace.""" + return obj.__narwhals_namespace__() + # NOTE: Unlike the version in `nw._utils`, here `.version` it is public class StoresVersion(Protocol): @@ -143,130 +160,11 @@ def _length_required( class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): - _DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = { - expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col( - node, frame, name - ), - expr.Literal: lambda self, node, frame, name: self.__narwhals_namespace__().lit( - node, frame, name - ), - expr.Len: lambda self, node, frame, name: self.__narwhals_namespace__().len( - node, frame, name - ), - expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name), - expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name), - expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name), - expr.Filter: lambda self, node, frame, name: self.filter(node, frame, name), - agg.First: lambda self, node, frame, name: self.first(node, frame, name), - agg.Last: lambda self, node, frame, name: self.last(node, frame, name), - agg.ArgMin: lambda self, node, frame, name: self.arg_min(node, frame, name), - agg.ArgMax: lambda self, node, frame, name: self.arg_max(node, frame, name), - agg.Sum: lambda self, node, frame, name: self.sum(node, frame, name), - agg.NUnique: lambda self, node, frame, name: self.n_unique(node, frame, name), - agg.Std: lambda self, node, frame, name: self.std(node, frame, name), - agg.Var: lambda self, node, frame, name: self.var(node, frame, name), - agg.Quantile: lambda self, node, frame, name: self.quantile(node, frame, name), - agg.Count: lambda self, node, frame, name: self.count(node, frame, name), - agg.Max: lambda self, node, frame, name: self.max(node, frame, name), - agg.Mean: lambda self, node, frame, name: self.mean(node, frame, name), - agg.Median: lambda self, node, frame, name: self.median(node, frame, name), - agg.Min: lambda self, node, frame, name: self.min(node, frame, name), - expr.BinaryExpr: lambda self, node, frame, name: self.binary_expr( - node, frame, name - ), - expr.RollingExpr: lambda self, node, frame, name: self.rolling_expr( - node, frame, name - ), - expr.AnonymousExpr: lambda self, node, frame, name: self.map_batches( - node, frame, name - ), - expr.FunctionExpr: lambda self, node, frame, name: self._dispatch_function( - node, frame, name - ), - # NOTE: Keeping it simple for now - # When adding other `*_range` functions, this should instead map to `range_expr` - expr.RangeExpr: lambda self, - node, - frame, - name: self.__narwhals_namespace__().int_range(node, frame, name), - expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered( - node, frame, name - ), - expr.WindowExpr: lambda self, node, frame, name: self.over(node, frame, name), - expr.Ternary: lambda self, node, frame, name: self.ternary_expr( - node, frame, name - ), - } - _DISPATCH_FUNCTION: ClassVar[ - Mapping[type[Function], Callable[[Any, FunctionExpr, Any, str], Any]] - ] = { - boolean.AnyHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().any_horizontal(node, frame, name), - boolean.AllHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().all_horizontal(node, frame, name), - F.SumHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().sum_horizontal(node, frame, name), - F.MinHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().min_horizontal(node, frame, name), - F.MaxHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().max_horizontal(node, frame, name), - F.MeanHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().mean_horizontal(node, frame, name), - strings.ConcatHorizontal: lambda self, - node, - frame, - name: self.__narwhals_namespace__().concat_str(node, frame, name), - F.Pow: lambda self, node, frame, name: self.pow(node, frame, name), - F.FillNull: lambda self, node, frame, name: self.fill_null(node, frame, name), - boolean.IsBetween: lambda self, node, frame, name: self.is_between( - node, frame, name - ), - boolean.IsFinite: lambda self, node, frame, name: self.is_finite( - node, frame, name - ), - boolean.IsNan: lambda self, node, frame, name: self.is_nan(node, frame, name), - boolean.IsNull: lambda self, node, frame, name: self.is_null(node, frame, name), - boolean.Not: lambda self, node, frame, name: self.not_(node, frame, name), - boolean.Any: lambda self, node, frame, name: self.any(node, frame, name), - boolean.All: lambda self, node, frame, name: self.all(node, frame, name), - } - - def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: - if (method := self._DISPATCH.get(node.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - - def _dispatch_function( - self, node: FunctionExpr, frame: FrameT_contra, name: str - ) -> R_co: - fn = node.function - if (method := self._DISPATCH_FUNCTION.get(fn.__class__)) and ( - result := method(self, node, frame, name) - ): - return result # type: ignore[no-any-return] - msg = f"Support for {fn.__class__.__name__!r} is not yet implemented, got:\n{node!r}" - raise NotImplementedError(msg) - @classmethod def from_ir(cls, node: ExprIR, frame: FrameT_contra, name: str) -> R_co: obj = cls.__new__(cls) obj._version = frame.version - return obj._dispatch(node, frame, name) + return node.dispatch(obj, frame, name) @classmethod def from_named_ir(cls, named_ir: NamedIR[ExprIR], frame: FrameT_contra) -> R_co: @@ -318,7 +216,7 @@ def binary_expr( self, node: expr.BinaryExpr, frame: FrameT_contra, name: str ) -> Self: ... def ternary_expr( - self, node: expr.Ternary, frame: FrameT_contra, name: str + self, node: expr.TernaryExpr, frame: FrameT_contra, name: str ) -> Self: ... def over(self, node: expr.WindowExpr, frame: FrameT_contra, name: str) -> Self: ... # NOTE: `Scalar` is returned **only** for un-partitioned `OrderableAggExpr` @@ -571,7 +469,7 @@ def mean_horizontal( self, node: FunctionExpr[F.MeanHorizontal], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def concat_str( - self, node: FunctionExpr[strings.ConcatHorizontal], frame: FrameT, name: str + self, node: FunctionExpr[ConcatStr], frame: FrameT, name: str ) -> ExprT_co | ScalarT_co: ... def int_range( self, node: RangeExpr[IntRange], frame: FrameT, name: str @@ -704,12 +602,10 @@ class EagerDataFrame( ): def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ... def select(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) def with_columns(self, irs: Seq[NamedIR]) -> Self: - ns = self.__narwhals_namespace__() - return ns._concat_horizontal(self._evaluate_irs(irs)) + return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs)) class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]): diff --git a/narwhals/_plan/ranges.py b/narwhals/_plan/ranges.py index 4414afabf7..4f8e49b531 100644 --- a/narwhals/_plan/ranges.py +++ b/narwhals/_plan/ranges.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprIR, Function -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing_extensions import Self @@ -12,7 +12,7 @@ from narwhals.dtypes import IntegerType -class RangeFunction(Function): +class RangeFunction(Function, config=FEOptions.namespaced()): def to_function_expr(self, *inputs: ExprIR) -> RangeExpr[Self]: from narwhals._plan.expr import RangeExpr diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 17b8416285..fea8827123 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -7,7 +7,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, TypeVar, overload -from narwhals._plan.common import _IMMUTABLE_HASH_NAME, Immutable, NamedIR +from narwhals._plan._immutable import _IMMUTABLE_HASH_NAME, Immutable +from narwhals._plan.common import NamedIR from narwhals.dtypes import Unknown if TYPE_CHECKING: diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index 3cd7666ddc..37de3f118e 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -9,7 +9,8 @@ import re from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, flatten_hash_safe +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import flatten_hash_safe from narwhals._utils import Version, _parse_time_unit_and_time_zone if TYPE_CHECKING: diff --git a/narwhals/_plan/strings.py b/narwhals/_plan/strings.py index 8a1789b079..4c03d2614d 100644 --- a/narwhals/_plan/strings.py +++ b/narwhals/_plan/strings.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import ExprNamespace, Function, IRNamespace +from narwhals._plan.common import ExprNamespace, Function, HorizontalFunction, IRNamespace from narwhals._plan.options import FunctionOptions if TYPE_CHECKING: @@ -12,9 +12,7 @@ class StringFunction(Function, accessor="str", options=FunctionOptions.elementwise): ... -class ConcatHorizontal(StringFunction, options=FunctionOptions.horizontal): - """`nw.functions.concat_str`.""" - +class ConcatStr(HorizontalFunction, StringFunction): __slots__ = ("ignore_nulls", "separator") separator: str ignore_nulls: bool diff --git a/narwhals/_plan/struct.py b/narwhals/_plan/struct.py index d91fef6458..7e1ed69ac3 100644 --- a/narwhals/_plan/struct.py +++ b/narwhals/_plan/struct.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan.common import ExprNamespace, Function, IRNamespace -from narwhals._plan.options import FunctionOptions +from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from narwhals._plan.dummy import Expr @@ -12,9 +12,9 @@ class StructFunction(Function, accessor="struct"): ... -class FieldByName(StructFunction, options=FunctionOptions.elementwise): - """https://github.com/pola-rs/polars/blob/62257860a43ec44a638e8492ed2cf98a49c05f2e/crates/polars-plan/src/dsl/function_expr/struct_.rs#L11.""" - +class FieldByName( + StructFunction, options=FunctionOptions.elementwise, config=FEOptions.renamed("field") +): __slots__ = ("name",) name: str diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index d264f39733..4e0b14040e 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan.common import Immutable, is_expr +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import is_expr from narwhals._plan.dummy import Expr from narwhals._plan.expr_parsing import ( parse_into_expr_ir, @@ -13,7 +14,7 @@ from collections.abc import Iterable from narwhals._plan.common import ExprIR - from narwhals._plan.expr import Ternary + from narwhals._plan.expr import TernaryExpr from narwhals._plan.typing import IntoExpr, IntoExprColumn, Seq @@ -116,7 +117,7 @@ def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] return super().__eq__(value) -def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary: - from narwhals._plan.expr import Ternary +def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: + from narwhals._plan.expr import TernaryExpr - return Ternary(predicate=predicate, truthy=truthy, falsy=falsy) + return TernaryExpr(predicate=predicate, truthy=truthy, falsy=falsy) diff --git a/narwhals/_plan/window.py b/narwhals/_plan/window.py index f575d9d303..af414db989 100644 --- a/narwhals/_plan/window.py +++ b/narwhals/_plan/window.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING -from narwhals._plan.common import Immutable, is_function_expr, is_window_expr +from narwhals._plan._immutable import Immutable +from narwhals._plan.common import is_function_expr, is_window_expr from narwhals._plan.exceptions import ( over_elementwise_error, over_nested_error, diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 3c5e97439e..6f9d0450ad 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -6,7 +6,7 @@ import pytest -from narwhals._plan.common import Immutable +from narwhals._plan._immutable import Immutable class Empty(Immutable): ...