Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 31 additions & 84 deletions narwhals/_plan/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import operator
import operator as op
from typing import TYPE_CHECKING

from narwhals._plan.common import Immutable, is_function_expr
Expand All @@ -9,7 +9,6 @@
binary_expr_multi_output_error,
binary_expr_shape_error,
)
from narwhals._plan.expr import BinarySelector

if TYPE_CHECKING:
from typing import Any, ClassVar
Expand All @@ -28,30 +27,19 @@


class Operator(Immutable):
_op: ClassVar[OperatorFn]
_func: ClassVar[OperatorFn]
_symbol: ClassVar[str]

def __repr__(self) -> str:
tp = type(self)
if tp in {Operator, SelectorOperator}:
return tp.__name__
m = {
Eq: "==",
NotEq: "!=",
Lt: "<",
LtEq: "<=",
Gt: ">",
GtEq: ">=",
Add: "+",
Sub: "-",
Multiply: "*",
TrueDivide: "/",
FloorDivide: "//",
Modulus: "%",
And: "&",
Or: "|",
ExclusiveOr: "^",
}
return m[tp]
return self._symbol

def __init_subclass__(
cls, *args: Any, func: OperatorFn | None, symbol: str = "", **kwds: Any
) -> None:
super().__init_subclass__(*args, **kwds)
if func:
cls._func = func
cls._symbol = symbol or cls.__name__

def to_binary_expr(
self, left: LeftT, right: RightT, /
Expand All @@ -72,7 +60,7 @@ def to_binary_expr(

def __call__(self, lhs: Any, rhs: Any) -> Any:
"""Apply binary operator to `left`, `right` operands."""
return self.__class__._op(lhs, rhs)
return self.__class__._func(lhs, rhs)


def _is_filtration(ir: ExprIR) -> bool:
Expand All @@ -81,7 +69,7 @@ def _is_filtration(ir: ExprIR) -> bool:
return False


class SelectorOperator(Operator):
class SelectorOperator(Operator, func=None):
"""Operators that can *also* be used in selectors."""

def to_binary_selector(
Expand All @@ -92,61 +80,20 @@ def to_binary_selector(
return BinarySelector(left=left, op=self, right=right)


class Eq(Operator):
_op = operator.eq


class NotEq(Operator):
_op = operator.ne


class Lt(Operator):
_op = operator.le


class LtEq(Operator):
_op = operator.lt


class Gt(Operator):
_op = operator.gt


class GtEq(Operator):
_op = operator.ge


class Add(Operator):
_op = operator.add


class Sub(SelectorOperator):
_op = operator.sub


class Multiply(Operator):
_op = operator.mul


class TrueDivide(Operator):
_op = operator.truediv


class FloorDivide(Operator):
_op = operator.floordiv


class Modulus(Operator):
_op = operator.mod


class And(SelectorOperator):
_op = operator.and_


class Or(SelectorOperator):
_op = operator.or_


class ExclusiveOr(SelectorOperator):
_op = operator.xor
# fmt: off
class Eq(Operator, func=op.eq, symbol="=="): ...
class NotEq(Operator, func=op.ne, symbol="!="): ...
class Lt(Operator, func=op.le, symbol="<"): ...
class LtEq(Operator, func=op.lt, symbol="<="): ...
class Gt(Operator, func=op.gt, symbol=">"): ...
class GtEq(Operator, func=op.ge, symbol=">="): ...
class Add(Operator, func=op.add, symbol="+"): ...
class Sub(SelectorOperator, func=op.sub, symbol="-"): ...
class Multiply(Operator, func=op.mul, symbol="*"): ...
class TrueDivide(Operator, func=op.truediv, symbol="/"): ...
class FloorDivide(Operator, func=op.floordiv, symbol="//"): ...
class Modulus(Operator, func=op.mod, symbol="%"): ...
class And(SelectorOperator, func=op.and_, symbol="&"): ...
class Or(SelectorOperator, func=op.or_, symbol="|"): ...
class ExclusiveOr(SelectorOperator, func=op.xor, symbol="^"): ...
# fmt: on
Loading