From fc69b38a43797fd1ca74a41e8273543c1e7d5e54 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 31 Aug 2025 18:53:52 +0000 Subject: [PATCH] refactor(expr-ir): Shrink `_plan.operators` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Looking *quite* dense now 😄 --- narwhals/_plan/operators.py | 115 ++++++++++-------------------------- 1 file changed, 31 insertions(+), 84 deletions(-) diff --git a/narwhals/_plan/operators.py b/narwhals/_plan/operators.py index 09d072e7bd..38c912c1bd 100644 --- a/narwhals/_plan/operators.py +++ b/narwhals/_plan/operators.py @@ -1,6 +1,6 @@ from __future__ import annotations -import operator +import operator as op from typing import TYPE_CHECKING from narwhals._plan.common import Immutable, is_function_expr @@ -9,7 +9,6 @@ binary_expr_multi_output_error, binary_expr_shape_error, ) -from narwhals._plan.expr import BinarySelector if TYPE_CHECKING: from typing import Any, ClassVar @@ -28,30 +27,19 @@ class Operator(Immutable): - _op: ClassVar[OperatorFn] + _func: ClassVar[OperatorFn] + _symbol: ClassVar[str] def __repr__(self) -> str: - tp = type(self) - if tp in {Operator, SelectorOperator}: - return tp.__name__ - m = { - Eq: "==", - NotEq: "!=", - Lt: "<", - LtEq: "<=", - Gt: ">", - GtEq: ">=", - Add: "+", - Sub: "-", - Multiply: "*", - TrueDivide: "/", - FloorDivide: "//", - Modulus: "%", - And: "&", - Or: "|", - ExclusiveOr: "^", - } - return m[tp] + return self._symbol + + def __init_subclass__( + cls, *args: Any, func: OperatorFn | None, symbol: str = "", **kwds: Any + ) -> None: + super().__init_subclass__(*args, **kwds) + if func: + cls._func = func + cls._symbol = symbol or cls.__name__ def to_binary_expr( self, left: LeftT, right: RightT, / @@ -72,7 +60,7 @@ def to_binary_expr( def __call__(self, lhs: Any, rhs: Any) -> Any: """Apply binary operator to `left`, `right` operands.""" - return self.__class__._op(lhs, rhs) + return self.__class__._func(lhs, rhs) def _is_filtration(ir: ExprIR) -> bool: @@ -81,7 +69,7 @@ def _is_filtration(ir: ExprIR) -> bool: return False -class SelectorOperator(Operator): +class SelectorOperator(Operator, func=None): """Operators that can *also* be used in selectors.""" def to_binary_selector( @@ -92,61 +80,20 @@ def to_binary_selector( return BinarySelector(left=left, op=self, right=right) -class Eq(Operator): - _op = operator.eq - - -class NotEq(Operator): - _op = operator.ne - - -class Lt(Operator): - _op = operator.le - - -class LtEq(Operator): - _op = operator.lt - - -class Gt(Operator): - _op = operator.gt - - -class GtEq(Operator): - _op = operator.ge - - -class Add(Operator): - _op = operator.add - - -class Sub(SelectorOperator): - _op = operator.sub - - -class Multiply(Operator): - _op = operator.mul - - -class TrueDivide(Operator): - _op = operator.truediv - - -class FloorDivide(Operator): - _op = operator.floordiv - - -class Modulus(Operator): - _op = operator.mod - - -class And(SelectorOperator): - _op = operator.and_ - - -class Or(SelectorOperator): - _op = operator.or_ - - -class ExclusiveOr(SelectorOperator): - _op = operator.xor +# fmt: off +class Eq(Operator, func=op.eq, symbol="=="): ... +class NotEq(Operator, func=op.ne, symbol="!="): ... +class Lt(Operator, func=op.le, symbol="<"): ... +class LtEq(Operator, func=op.lt, symbol="<="): ... +class Gt(Operator, func=op.gt, symbol=">"): ... +class GtEq(Operator, func=op.ge, symbol=">="): ... +class Add(Operator, func=op.add, symbol="+"): ... +class Sub(SelectorOperator, func=op.sub, symbol="-"): ... +class Multiply(Operator, func=op.mul, symbol="*"): ... +class TrueDivide(Operator, func=op.truediv, symbol="/"): ... +class FloorDivide(Operator, func=op.floordiv, symbol="//"): ... +class Modulus(Operator, func=op.mod, symbol="%"): ... +class And(SelectorOperator, func=op.and_, symbol="&"): ... +class Or(SelectorOperator, func=op.or_, symbol="|"): ... +class ExclusiveOr(SelectorOperator, func=op.xor, symbol="^"): ... +# fmt: on