Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StrictMetricsEvaluator #518

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 319 additions & 33 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
DoubleType,
FloatType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
Expand Down Expand Up @@ -534,7 +535,9 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool:


ROWS_MIGHT_MATCH = True
ROWS_MUST_MATCH = True
ROWS_CANNOT_MATCH = False
ROWS_MIGHT_NOT_MATCH = False
IN_PREDICATE_LIMIT = 200


Expand Down Expand Up @@ -1089,16 +1092,52 @@ def expression_to_plain_format(
return [visit(expression, visitor) for expression in expressions]


class _InclusiveMetricsEvaluator(BoundBooleanExpressionVisitor[bool]):
struct: StructType
expr: BooleanExpression

class _MetricsEvaluator(BoundBooleanExpressionVisitor[bool], ABC):
value_counts: Dict[int, int]
null_counts: Dict[int, int]
nan_counts: Dict[int, int]
lower_bounds: Dict[int, bytes]
upper_bounds: Dict[int, bytes]

def visit_true(self) -> bool:
# all rows match
return ROWS_MIGHT_MATCH

def visit_false(self) -> bool:
# all rows fail
return ROWS_CANNOT_MATCH

def visit_not(self, child_result: bool) -> bool:
raise ValueError(f"NOT should be rewritten: {child_result}")

def visit_and(self, left_result: bool, right_result: bool) -> bool:
return left_result and right_result

def visit_or(self, left_result: bool, right_result: bool) -> bool:
return left_result or right_result

def _contains_nulls_only(self, field_id: int) -> bool:
if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)):
return value_count == null_count
return False

def _contains_nans_only(self, field_id: int) -> bool:
if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
return nan_count == value_count
return False

def _is_nan(self, val: Any) -> bool:
try:
return math.isnan(val)
except TypeError:
# In the case of None or other non-numeric types
return False


class _InclusiveMetricsEvaluator(_MetricsEvaluator):
struct: StructType
expr: BooleanExpression

def __init__(
self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
) -> None:
Expand Down Expand Up @@ -1128,40 +1167,11 @@ def eval(self, file: DataFile) -> bool:
def _may_contain_null(self, field_id: int) -> bool:
return self.null_counts is None or (field_id in self.null_counts and self.null_counts.get(field_id) is not None)

def _contains_nulls_only(self, field_id: int) -> bool:
if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)):
return value_count == null_count
return False

def _contains_nans_only(self, field_id: int) -> bool:
if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
return nan_count == value_count
return False

def _is_nan(self, val: Any) -> bool:
try:
return math.isnan(val)
except TypeError:
# In the case of None or other non-numeric types
return False

def visit_true(self) -> bool:
# all rows match
return ROWS_MIGHT_MATCH

def visit_false(self) -> bool:
# all rows fail
return ROWS_CANNOT_MATCH

def visit_not(self, child_result: bool) -> bool:
raise ValueError(f"NOT should be rewritten: {child_result}")

def visit_and(self, left_result: bool, right_result: bool) -> bool:
return left_result and right_result

def visit_or(self, left_result: bool, right_result: bool) -> bool:
return left_result or right_result

def visit_is_null(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

Expand Down Expand Up @@ -1421,3 +1431,279 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
return ROWS_CANNOT_MATCH

return ROWS_MIGHT_MATCH


class _StrictMetricsEvaluator(_MetricsEvaluator):
struct: StructType
expr: BooleanExpression

def __init__(
self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
) -> None:
self.struct = schema.as_struct()
self.include_empty_files = include_empty_files
self.expr = bind(schema, rewrite_not(expr), case_sensitive)

def eval(self, file: DataFile) -> bool:
"""Test whether all records within the file match the expression.
Args:
file: A data file
Returns: false if the file may contain any row that doesn't match
the expression, true otherwise.
"""
if file.record_count <= 0:
# Older version don't correctly implement record count from avro file and thus
# set record count -1 when importing avro tables to iceberg tables. This should
# be updated once we implemented and set correct record count.
return ROWS_MUST_MATCH

self.value_counts = file.value_counts or EMPTY_DICT
self.null_counts = file.null_value_counts or EMPTY_DICT
self.nan_counts = file.nan_value_counts or EMPTY_DICT
self.lower_bounds = file.lower_bounds or EMPTY_DICT
self.upper_bounds = file.upper_bounds or EMPTY_DICT

return visit(self.expr, self)

def visit_is_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id

if self._contains_nulls_only(field_id):
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_not_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id

if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0:
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_is_nan(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

if self._contains_nans_only(field_id):
return ROWS_MUST_MATCH
else:
return ROWS_MIGHT_NOT_MATCH

def visit_not_nan(self, term: BoundTerm[L]) -> bool:
field_id = term.ref().field.field_id

if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0:
return ROWS_MUST_MATCH

if self._contains_nulls_only(field_id):
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <----------Min----Max---X------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <----------Min----Max---X------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper <= literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <-------X---Min----Max---------->

field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower > literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when: <-------X---Min----Max---------->
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower >= literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when Min == X == Max
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if lower != literal.value or upper != literal.value:
return ROWS_MIGHT_NOT_MATCH
else:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
# Rows must match when X < Min or Max < X because it is not in the range
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the _StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

if lower > literal.value:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

field = self._get_field(field_id)

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
# similar to the implementation in eq, first check if the lower bound is in the set
lower = _from_byte_buffer(field.field_type, lower_bytes)
if lower not in literals:
return ROWS_MIGHT_NOT_MATCH

# check if the upper bound is in the set
upper = _from_byte_buffer(field.field_type, upper_bytes)
if upper not in literals:
return ROWS_MIGHT_NOT_MATCH

# finally check if the lower bound and the upper bound are equal
if lower != upper:
return ROWS_MIGHT_NOT_MATCH

# All values must be in the set if the lower bound and the upper bound are
# in the set and are equal.
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
field_id = term.ref().field.field_id

if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
# NaN indicates unreliable bounds.
# See the StrictMetricsEvaluator docs for more.
return ROWS_MIGHT_NOT_MATCH

literals = {val for val in literals if lower <= val}
if len(literals) == 0:
return ROWS_MUST_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
upper = _from_byte_buffer(field.field_type, upper_bytes)

literals = {val for val in literals if upper >= val}

if len(literals) == 0:
return ROWS_MUST_MATCH

return ROWS_MIGHT_NOT_MATCH

def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

def _get_field(self, field_id: int) -> NestedField:
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

return field

def _can_contain_nulls(self, field_id: int) -> bool:
return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0

def _can_contain_nans(self, field_id: int) -> bool:
return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0
Loading