Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 49 additions & 90 deletions python/pyiceberg/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,43 @@


class BooleanExpression(ABC):
"""Represents a boolean expression tree."""
"""An expression that evaluates to a boolean"""

@abstractmethod
def __invert__(self) -> BooleanExpression:
"""Transform the Expression into its negated version."""


class Bound(Generic[T], ABC):
"""Represents a bound value expression."""
class Term(Generic[T], ABC):
"""A simple expression that evaluates to a value"""

def eval(self, struct: StructProtocol): # pylint: disable=W0613
... # pragma: no cover

class Bound(ABC):
"""Represents a bound value expression"""


class Unbound(Generic[T, B], ABC):
"""Represents an unbound expression node."""
class Unbound(Generic[B], ABC):
"""Represents an unbound value expression"""

@abstractmethod
def bind(self, schema: Schema, case_sensitive: bool) -> B:
def bind(self, schema: Schema, case_sensitive: bool = True) -> B:
... # pragma: no cover


class Term(ABC):
"""An expression that evaluates to a value."""


class BaseReference(Generic[T], Term, ABC):
"""Represents a variable reference in an expression."""


class BoundTerm(Bound[T], Term):
"""Represents a bound term."""
class BoundTerm(Term[T], Bound, ABC):
"""Represents a bound term"""

@abstractmethod
def ref(self) -> BoundReference[T]:
...


class UnboundTerm(Unbound[T, BoundTerm[T]], Term):
"""Represents an unbound term."""
@abstractmethod
def eval(self, struct: StructProtocol): # pylint: disable=W0613
... # pragma: no cover


@dataclass(frozen=True)
class BoundReference(BoundTerm[T], BaseReference[T]):
class BoundReference(BoundTerm[T]):
"""A reference bound to a field in a schema

Args:
Expand All @@ -88,6 +81,7 @@ class BoundReference(BoundTerm[T], BaseReference[T]):

def eval(self, struct: StructProtocol) -> T:
"""Returns the value at the referenced field's position in an object that abides by the StructProtocol

Args:
struct (StructProtocol): A row object that abides by the StructProtocol and returns values given a position
Returns:
Expand All @@ -99,8 +93,12 @@ def ref(self) -> BoundReference[T]:
return self


class UnboundTerm(Term[T], Unbound[BoundTerm[T]], ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit / non-blocking: Unbound[BoundTerm[T]] is a somewhat confusing type signature (particularly when coming from the Java side).

However, UnboundTerm itself is clear as day, so that's not particularly a concern.

"""Represents an unbound term."""


@dataclass(frozen=True)
class Reference(UnboundTerm[T], BaseReference[T]):
class Reference(UnboundTerm[T]):
"""A reference not yet bound to a field in a schema

Args:
Expand All @@ -112,7 +110,7 @@ class Reference(UnboundTerm[T], BaseReference[T]):

name: str

def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]:
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[T]:
"""Bind the reference to an Iceberg schema

Args:
Expand All @@ -125,22 +123,24 @@ def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]:
Returns:
BoundReference: A reference bound to the specific field in the Iceberg schema
"""
field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive) # pylint: disable=redefined-outer-name

field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive)
if not field:
raise ValueError(f"Cannot find field '{self.name}' in schema: {schema}")

accessor = schema.accessor_for_field(field.field_id)

if not accessor:
raise ValueError(f"Cannot find accessor for field '{self.name}' in schema: {schema}")

return BoundReference(field=field, accessor=accessor)


@dataclass(frozen=True, init=False)
class And(BooleanExpression):
"""AND operation expression - logical conjunction"""

left: BooleanExpression
right: BooleanExpression

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
if rest:
return reduce(And, (left, right, *rest))
Expand All @@ -150,35 +150,23 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole
return right
elif right is AlwaysTrue():
return left
self = super().__new__(cls)
self._left = left # type: ignore
self._right = right # type: ignore
return self

@property
def left(self) -> BooleanExpression:
return self._left # type: ignore

@property
def right(self) -> BooleanExpression:
return self._right # type: ignore

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, And) and self.left == other.left and self.right == other.right)
else:
result = super().__new__(cls)
object.__setattr__(result, "left", left)
object.__setattr__(result, "right", right)
return result

def __invert__(self) -> Or:
return Or(~self.left, ~self.right)

def __repr__(self) -> str:
return f"And({repr(self.left)}, {repr(self.right)})"

def __str__(self) -> str:
return f"And({str(self.left)}, {str(self.right)})"


@dataclass(frozen=True, init=False)
class Or(BooleanExpression):
"""OR operation expression - logical disjunction"""

left: BooleanExpression
right: BooleanExpression

def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
if rest:
return reduce(Or, (left, right, *rest))
Expand All @@ -188,59 +176,36 @@ def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: Boole
return right
elif right is AlwaysFalse():
return left
self = super().__new__(cls)
self._left = left # type: ignore
self._right = right # type: ignore
return self

@property
def left(self) -> BooleanExpression:
return self._left # type: ignore

@property
def right(self) -> BooleanExpression:
return self._right # type: ignore

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, Or) and self.left == other.left and self.right == other.right)
else:
result = super().__new__(cls)
object.__setattr__(result, "left", left)
object.__setattr__(result, "right", right)
return result

def __invert__(self) -> And:
return And(~self.left, ~self.right)

def __repr__(self) -> str:
return f"Or({repr(self.left)}, {repr(self.right)})"

def __str__(self) -> str:
return f"Or({str(self.left)}, {str(self.right)})"


@dataclass(frozen=True, init=False)
class Not(BooleanExpression):
"""NOT operation expression - logical negation"""

child: BooleanExpression

def __new__(cls, child: BooleanExpression):
if child is AlwaysTrue():
return AlwaysFalse()
elif child is AlwaysFalse():
return AlwaysTrue()
elif isinstance(child, Not):
return child.child
return super().__new__(cls)

def __init__(self, child):
self.child = child

def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, Not) and self.child == other.child)
result = super().__new__(cls)
object.__setattr__(result, "child", child)
return result

def __invert__(self) -> BooleanExpression:
return self.child

def __repr__(self) -> str:
return f"Not({repr(self.child)})"

def __str__(self) -> str:
return f"Not({str(self.child)})"


@dataclass(frozen=True)
class AlwaysTrue(BooleanExpression, Singleton):
Expand All @@ -259,15 +224,15 @@ def __invert__(self) -> AlwaysTrue:


@dataclass(frozen=True)
class BoundPredicate(Bound[T], BooleanExpression):
class BoundPredicate(Generic[T], Bound, BooleanExpression):
term: BoundTerm[T]

def __invert__(self) -> BoundPredicate[T]:
raise NotImplementedError


@dataclass(frozen=True)
class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression):
class UnboundPredicate(Generic[T], Unbound[BooleanExpression], BooleanExpression):
as_bound: ClassVar[type]
term: UnboundTerm[T]

Expand Down Expand Up @@ -661,12 +626,6 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T:
return visitor.visit_and(left_result=left_result, right_result=right_result)


@visit.register(In)
def _(obj: In, visitor: BooleanExpressionVisitor[T]) -> T:
"""Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
return visitor.visit_unbound_predicate(predicate=obj)


@visit.register(UnboundPredicate)
def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T:
"""Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
Expand Down
12 changes: 6 additions & 6 deletions python/tests/expressions/test_expressions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def _(obj: ExpressionB, visitor: BooleanExpressionVisitor) -> List:
[
(
base.And(ExpressionA(), ExpressionB()),
"And(ExpressionA(), ExpressionB())",
"And(left=ExpressionA(), right=ExpressionB())",
),
(
base.Or(ExpressionA(), ExpressionB()),
"Or(ExpressionA(), ExpressionB())",
"Or(left=ExpressionA(), right=ExpressionB())",
),
(base.Not(ExpressionA()), "Not(ExpressionA())"),
(base.Not(ExpressionA()), "Not(child=ExpressionA())"),
],
)
def test_reprs(op, rep):
Expand Down Expand Up @@ -208,9 +208,9 @@ def test_notnan_bind_nonfloat():
@pytest.mark.parametrize(
"op, string",
[
(base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"),
(base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"),
(base.Not(ExpressionA()), "Not(testexpra)"),
(base.And(ExpressionA(), ExpressionB()), "And(left=ExpressionA(), right=ExpressionB())"),
(base.Or(ExpressionA(), ExpressionB()), "Or(left=ExpressionA(), right=ExpressionB())"),
(base.Not(ExpressionA()), "Not(child=ExpressionA())"),
Copy link
Contributor

@dramaticlly dramaticlly Jul 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since this file is being updated anyway, it may be worth simplifying lines 220 to 245 as Sam suggested in https://github.com/apache/iceberg/pull/5362/files#r932825540

def test_ref_binding_case_sensitive(request):
    schema = request.getfixturevalue("table_schema_simple")

can be replaced by:

def test_ref_binding_case_sensitive(table_schema_simple):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't make that change because I think it is easier to understand what's happening when we use less magic. Using getfixturevalue makes it clear that this is getting a fixture called "table_schema_simple". Otherwise where that comes from is magic and confusing (at least to me).

],
)
def test_strs(op, string):
Expand Down