Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix equality of bound expressions #95

Merged
merged 1 commit into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(self, term: BoundTerm[L]):

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundPredicate class."""
if isinstance(other, BoundPredicate):
if isinstance(other, self.__class__):
return self.term == other.term
return False

Expand Down Expand Up @@ -567,7 +567,7 @@ def __repr__(self) -> str:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundSetPredicate class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundSetPredicate) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

def __getnewargs__(self) -> Tuple[BoundTerm[L], Set[Literal[L]]]:
"""Pickle the BoundSetPredicate class."""
Expand Down Expand Up @@ -595,7 +595,7 @@ def __invert__(self) -> BoundNotIn[L]:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundIn class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundIn) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

@property
def as_unbound(self) -> Type[In[L]]:
Expand Down Expand Up @@ -725,7 +725,7 @@ def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable=

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundLiteralPredicate class."""
if isinstance(other, BoundLiteralPredicate):
if isinstance(other, self.__class__):
return self.term == other.term and self.literal == other.literal
return False

Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from pyiceberg import schema
from pyiceberg.catalog import Catalog
from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import BoundReference
from pyiceberg.io import (
GCS_ENDPOINT,
GCS_PROJECT_ID,
Expand All @@ -69,7 +70,7 @@
)
from pyiceberg.io.fsspec import FsspecFileIO
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.schema import Schema
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
Expand Down Expand Up @@ -1659,3 +1660,8 @@ def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
io=load_file_io(),
catalog=NoopCatalog("NoopCatalog"),
)


@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))
9 changes: 9 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,15 @@ def test_above_long_bounds_greater_than_or_equal(
assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue()


def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None:
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) != BoundGreaterThanOrEqual(
term=bound_reference_str, literal=literal('a')
)
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) == BoundEqualTo(
term=bound_reference_str, literal=literal('a')
)


# __ __ ___
# | \/ |_ _| _ \_ _
# | |\/| | || | _/ || |
Expand Down
5 changes: 0 additions & 5 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,6 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr:
assert repr(transform) == transform_repr


@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture
def bound_reference_date() -> BoundReference[int]:
return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None))
Expand Down