diff --git a/python/pyiceberg/expressions/base.py b/python/pyiceberg/expressions/base.py index f91093b17ad4..f4584fdb4ca5 100644 --- a/python/pyiceberg/expressions/base.py +++ b/python/pyiceberg/expressions/base.py @@ -32,50 +32,43 @@ class BooleanExpression(ABC): - """Represents a boolean expression tree.""" + """An expression that evaluates to a boolean""" @abstractmethod def __invert__(self) -> BooleanExpression: """Transform the Expression into its negated version.""" -class Bound(Generic[T], ABC): - """Represents a bound value expression.""" +class Term(Generic[T], ABC): + """A simple expression that evaluates to a value""" - def eval(self, struct: StructProtocol): # pylint: disable=W0613 - ... # pragma: no cover + +class Bound(ABC): + """Represents a bound value expression""" -class Unbound(Generic[T, B], ABC): - """Represents an unbound expression node.""" +class Unbound(Generic[B], ABC): + """Represents an unbound value expression""" @abstractmethod - def bind(self, schema: Schema, case_sensitive: bool) -> B: + def bind(self, schema: Schema, case_sensitive: bool = True) -> B: ... # pragma: no cover -class Term(ABC): - """An expression that evaluates to a value.""" - - -class BaseReference(Generic[T], Term, ABC): - """Represents a variable reference in an expression.""" - - -class BoundTerm(Bound[T], Term): - """Represents a bound term.""" +class BoundTerm(Term[T], Bound, ABC): + """Represents a bound term""" @abstractmethod def ref(self) -> BoundReference[T]: ... - -class UnboundTerm(Unbound[T, BoundTerm[T]], Term): - """Represents an unbound term.""" + @abstractmethod + def eval(self, struct: StructProtocol): # pylint: disable=W0613 + ... # pragma: no cover @dataclass(frozen=True) -class BoundReference(BoundTerm[T], BaseReference[T]): +class BoundReference(BoundTerm[T]): """A reference bound to a field in a schema Args: @@ -88,6 +81,7 @@ class BoundReference(BoundTerm[T], BaseReference[T]): def eval(self, struct: StructProtocol) -> T: """Returns the value at the referenced field's position in an object that abides by the StructProtocol + Args: struct (StructProtocol): A row object that abides by the StructProtocol and returns values given a position Returns: @@ -99,8 +93,12 @@ def ref(self) -> BoundReference[T]: return self +class UnboundTerm(Term[T], Unbound[BoundTerm[T]], ABC): + """Represents an unbound term.""" + + @dataclass(frozen=True) -class Reference(UnboundTerm[T], BaseReference[T]): +class Reference(UnboundTerm[T]): """A reference not yet bound to a field in a schema Args: @@ -112,7 +110,7 @@ class Reference(UnboundTerm[T], BaseReference[T]): name: str - def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]: + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[T]: """Bind the reference to an Iceberg schema Args: @@ -125,22 +123,24 @@ def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]: Returns: BoundReference: A reference bound to the specific field in the Iceberg schema """ - field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) # pylint: disable=redefined-outer-name - + field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) if not field: raise ValueError(f"Cannot find field '{self.name}' in schema: {schema}") accessor = schema.accessor_for_field(field.field_id) - if not accessor: raise ValueError(f"Cannot find accessor for field '{self.name}' in schema: {schema}") return BoundReference(field=field, accessor=accessor) +@dataclass(frozen=True, init=False) class And(BooleanExpression): """AND operation expression - logical conjunction""" + left: BooleanExpression + right: BooleanExpression + def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression): if rest: return reduce(And, (left, right, *rest)) @@ -150,35 +150,23 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole return right elif right is AlwaysTrue(): return left - self = super().__new__(cls) - self._left = left # type: ignore - self._right = right # type: ignore - return self - - @property - def left(self) -> BooleanExpression: - return self._left # type: ignore - - @property - def right(self) -> BooleanExpression: - return self._right # type: ignore - - def __eq__(self, other) -> bool: - return id(self) == id(other) or (isinstance(other, And) and self.left == other.left and self.right == other.right) + else: + result = super().__new__(cls) + object.__setattr__(result, "left", left) + object.__setattr__(result, "right", right) + return result def __invert__(self) -> Or: return Or(~self.left, ~self.right) - def __repr__(self) -> str: - return f"And({repr(self.left)}, {repr(self.right)})" - - def __str__(self) -> str: - return f"And({str(self.left)}, {str(self.right)})" - +@dataclass(frozen=True, init=False) class Or(BooleanExpression): """OR operation expression - logical disjunction""" + left: BooleanExpression + right: BooleanExpression + def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression): if rest: return reduce(Or, (left, right, *rest)) @@ -188,35 +176,22 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole return right elif right is AlwaysFalse(): return left - self = super().__new__(cls) - self._left = left # type: ignore - self._right = right # type: ignore - return self - - @property - def left(self) -> BooleanExpression: - return self._left # type: ignore - - @property - def right(self) -> BooleanExpression: - return self._right # type: ignore - - def __eq__(self, other) -> bool: - return id(self) == id(other) or (isinstance(other, Or) and self.left == other.left and self.right == other.right) + else: + result = super().__new__(cls) + object.__setattr__(result, "left", left) + object.__setattr__(result, "right", right) + return result def __invert__(self) -> And: return And(~self.left, ~self.right) - def __repr__(self) -> str: - return f"Or({repr(self.left)}, {repr(self.right)})" - - def __str__(self) -> str: - return f"Or({str(self.left)}, {str(self.right)})" - +@dataclass(frozen=True, init=False) class Not(BooleanExpression): """NOT operation expression - logical negation""" + child: BooleanExpression + def __new__(cls, child: BooleanExpression): if child is AlwaysTrue(): return AlwaysFalse() @@ -224,23 +199,13 @@ def __new__(cls, child: BooleanExpression): return AlwaysTrue() elif isinstance(child, Not): return child.child - return super().__new__(cls) - - def __init__(self, child): - self.child = child - - def __eq__(self, other) -> bool: - return id(self) == id(other) or (isinstance(other, Not) and self.child == other.child) + result = super().__new__(cls) + object.__setattr__(result, "child", child) + return result def __invert__(self) -> BooleanExpression: return self.child - def __repr__(self) -> str: - return f"Not({repr(self.child)})" - - def __str__(self) -> str: - return f"Not({str(self.child)})" - @dataclass(frozen=True) class AlwaysTrue(BooleanExpression, Singleton): @@ -259,7 +224,7 @@ def __invert__(self) -> AlwaysTrue: @dataclass(frozen=True) -class BoundPredicate(Bound[T], BooleanExpression): +class BoundPredicate(Generic[T], Bound, BooleanExpression): term: BoundTerm[T] def __invert__(self) -> BoundPredicate[T]: @@ -267,7 +232,7 @@ def __invert__(self) -> BoundPredicate[T]: @dataclass(frozen=True) -class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression): +class UnboundPredicate(Generic[T], Unbound[BooleanExpression], BooleanExpression): as_bound: ClassVar[type] term: UnboundTerm[T] @@ -661,12 +626,6 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T: return visitor.visit_and(left_result=left_result, right_result=right_result) -@visit.register(In) -def _(obj: In, visitor: BooleanExpressionVisitor[T]) -> T: - """Visit an In boolean expression with a concrete BooleanExpressionVisitor""" - return visitor.visit_unbound_predicate(predicate=obj) - - @visit.register(UnboundPredicate) def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T: """Visit an In boolean expression with a concrete BooleanExpressionVisitor""" diff --git a/python/tests/expressions/test_expressions_base.py b/python/tests/expressions/test_expressions_base.py index ba2850133bb0..cf7429829642 100644 --- a/python/tests/expressions/test_expressions_base.py +++ b/python/tests/expressions/test_expressions_base.py @@ -120,13 +120,13 @@ def _(obj: ExpressionB, visitor: BooleanExpressionVisitor) -> List: [ ( base.And(ExpressionA(), ExpressionB()), - "And(ExpressionA(), ExpressionB())", + "And(left=ExpressionA(), right=ExpressionB())", ), ( base.Or(ExpressionA(), ExpressionB()), - "Or(ExpressionA(), ExpressionB())", + "Or(left=ExpressionA(), right=ExpressionB())", ), - (base.Not(ExpressionA()), "Not(ExpressionA())"), + (base.Not(ExpressionA()), "Not(child=ExpressionA())"), ], ) def test_reprs(op, rep): @@ -208,9 +208,9 @@ def test_notnan_bind_nonfloat(): @pytest.mark.parametrize( "op, string", [ - (base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"), - (base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"), - (base.Not(ExpressionA()), "Not(testexpra)"), + (base.And(ExpressionA(), ExpressionB()), "And(left=ExpressionA(), right=ExpressionB())"), + (base.Or(ExpressionA(), ExpressionB()), "Or(left=ExpressionA(), right=ExpressionB())"), + (base.Not(ExpressionA()), "Not(child=ExpressionA())"), ], ) def test_strs(op, string):