diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 79bc995198..244e4d2e87 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -534,7 +534,9 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool: ROWS_MIGHT_MATCH = True +ROWS_MUST_MATCH = True ROWS_CANNOT_MATCH = False +ROWS_MIGHT_NOT_MATCH = False IN_PREDICATE_LIMIT = 200 @@ -1089,16 +1091,52 @@ def expression_to_plain_format( return [visit(expression, visitor) for expression in expressions] -class _InclusiveMetricsEvaluator(BoundBooleanExpressionVisitor[bool]): - struct: StructType - expr: BooleanExpression - +class _MetricsEvaluator(BoundBooleanExpressionVisitor[bool], ABC): value_counts: Dict[int, int] null_counts: Dict[int, int] nan_counts: Dict[int, int] lower_bounds: Dict[int, bytes] upper_bounds: Dict[int, bytes] + def visit_true(self) -> bool: + # all rows match + return ROWS_MIGHT_MATCH + + def visit_false(self) -> bool: + # all rows fail + return ROWS_CANNOT_MATCH + + def visit_not(self, child_result: bool) -> bool: + raise ValueError(f"NOT should be rewritten: {child_result}") + + def visit_and(self, left_result: bool, right_result: bool) -> bool: + return left_result and right_result + + def visit_or(self, left_result: bool, right_result: bool) -> bool: + return left_result or right_result + + def _contains_nulls_only(self, field_id: int) -> bool: + if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)): + return value_count == null_count + return False + + def _contains_nans_only(self, field_id: int) -> bool: + if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)): + return nan_count == value_count + return False + + def _is_nan(self, val: Any) -> bool: + try: + return math.isnan(val) + except TypeError: + # In the case of None or other non-numeric types + return False + + +class _InclusiveMetricsEvaluator(_MetricsEvaluator): + struct: StructType + expr: BooleanExpression + def __init__( self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False ) -> None: @@ -1128,40 +1166,11 @@ def eval(self, file: DataFile) -> bool: def _may_contain_null(self, field_id: int) -> bool: return self.null_counts is None or (field_id in self.null_counts and self.null_counts.get(field_id) is not None) - def _contains_nulls_only(self, field_id: int) -> bool: - if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)): - return value_count == null_count - return False - def _contains_nans_only(self, field_id: int) -> bool: if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)): return nan_count == value_count return False - def _is_nan(self, val: Any) -> bool: - try: - return math.isnan(val) - except TypeError: - # In the case of None or other non-numeric types - return False - - def visit_true(self) -> bool: - # all rows match - return ROWS_MIGHT_MATCH - - def visit_false(self) -> bool: - # all rows fail - return ROWS_CANNOT_MATCH - - def visit_not(self, child_result: bool) -> bool: - raise ValueError(f"NOT should be rewritten: {child_result}") - - def visit_and(self, left_result: bool, right_result: bool) -> bool: - return left_result and right_result - - def visit_or(self, left_result: bool, right_result: bool) -> bool: - return left_result or right_result - def visit_is_null(self, term: BoundTerm[L]) -> bool: field_id = term.ref().field.field_id @@ -1421,3 +1430,299 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool return ROWS_CANNOT_MATCH return ROWS_MIGHT_MATCH + + +class _StrictMetricsEvaluator(_MetricsEvaluator): + struct: StructType + expr: BooleanExpression + + def __init__( + self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False + ) -> None: + self.struct = schema.as_struct() + self.include_empty_files = include_empty_files + self.expr = bind(schema, rewrite_not(expr), case_sensitive) + + def eval(self, file: DataFile) -> bool: + """Test whether all records within the file match the expression. + + Args: + file: A data file + + Returns: false if the file may contain any row that doesn't match + the expression, true otherwise. + """ + if file.record_count <= 0: + # Older version don't correctly implement record count from avro file and thus + # set record count -1 when importing avro tables to iceberg tables. This should + # be updated once we implemented and set correct record count. + return ROWS_MUST_MATCH + + self.value_counts = file.value_counts or EMPTY_DICT + self.null_counts = file.null_value_counts or EMPTY_DICT + self.nan_counts = file.nan_value_counts or EMPTY_DICT + self.lower_bounds = file.lower_bounds or EMPTY_DICT + self.upper_bounds = file.upper_bounds or EMPTY_DICT + + return visit(self.expr, self) + + def visit_is_null(self, term: BoundTerm[L]) -> bool: + # no need to check whether the field is required because binding evaluates that case + # if the column has any non-null values, the expression does not match + field_id = term.ref().field.field_id + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + if self._contains_nulls_only(field_id): + return ROWS_MUST_MATCH + else: + return ROWS_MIGHT_NOT_MATCH + + def visit_not_null(self, term: BoundTerm[L]) -> bool: + # no need to check whether the field is required because binding evaluates that case + # if the column has any non-null values, the expression does not match + field_id = term.ref().field.field_id + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0: + return ROWS_MUST_MATCH + else: + return ROWS_MIGHT_NOT_MATCH + + def visit_is_nan(self, term: BoundTerm[L]) -> bool: + field_id = term.ref().field.field_id + + if self._contains_nans_only(field_id): + return ROWS_MUST_MATCH + else: + return ROWS_MIGHT_NOT_MATCH + + def visit_not_nan(self, term: BoundTerm[L]) -> bool: + field_id = term.ref().field.field_id + + if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0: + return ROWS_MUST_MATCH + + if self._contains_nulls_only(field_id): + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when: <----------Min----Max---X-------> + + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + if upper_bytes := self.upper_bounds.get(field_id): + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + upper = _from_byte_buffer(field.field_type, upper_bytes) + + if upper < literal.value: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when: <----------Min----Max---X-------> + + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + if upper_bytes := self.upper_bounds.get(field_id): + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + upper = _from_byte_buffer(field.field_type, upper_bytes) + + if upper <= literal.value: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when: <-------X---Min----Max----------> + + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + if lower_bytes := self.lower_bounds.get(field_id): + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + lower = _from_byte_buffer(field.field_type, lower_bytes) + + if self._is_nan(lower): + # NaN indicates unreliable bounds. + # See the _StrictMetricsEvaluator docs for more. + return ROWS_MIGHT_NOT_MATCH + + if lower > literal.value: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when: <-------X---Min----Max----------> + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + if lower_bytes := self.lower_bounds.get(field_id): + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + lower = _from_byte_buffer(field.field_type, lower_bytes) + + if self._is_nan(lower): + # NaN indicates unreliable bounds. + # See the _StrictMetricsEvaluator docs for more. + return ROWS_MIGHT_NOT_MATCH + + if lower >= literal.value: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when Min == X == Max + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + lower = _from_byte_buffer(field.field_type, lower_bytes) + upper = _from_byte_buffer(field.field_type, upper_bytes) + + if lower != literal.value or upper != literal.value: + return ROWS_MIGHT_NOT_MATCH + else: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + # Rows must match when X < Min or Max < X because it is not in the range + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MUST_MATCH + + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + if lower_bytes := self.lower_bounds.get(field_id): + lower = _from_byte_buffer(field.field_type, lower_bytes) + + if self._is_nan(lower): + # NaN indicates unreliable bounds. + # See the _StrictMetricsEvaluator docs for more. + return ROWS_MIGHT_NOT_MATCH + + if lower > literal.value: + return ROWS_MUST_MATCH + + if upper_bytes := self.upper_bounds.get(field_id): + upper = _from_byte_buffer(field.field_type, upper_bytes) + + if upper < literal.value: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MIGHT_NOT_MATCH + + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): + # similar to the implementation in eq, first check if the lower bound is in the set + lower = _from_byte_buffer(field.field_type, lower_bytes) + if lower not in literals: + return ROWS_MIGHT_NOT_MATCH + + # check if the upper bound is in the set + upper = _from_byte_buffer(field.field_type, upper_bytes) + if upper not in literals: + return ROWS_MIGHT_NOT_MATCH + + # finally check if the lower bound and the upper bound are equal + if lower != upper: + return ROWS_MIGHT_NOT_MATCH + + # All values must be in the set if the lower bound and the upper bound are + # in the set and are equal. + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + field_id = term.ref().field.field_id + + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): + return ROWS_MUST_MATCH + + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + if lower_bytes := self.lower_bounds.get(field_id): + lower = _from_byte_buffer(field.field_type, lower_bytes) + + if self._is_nan(lower): + # NaN indicates unreliable bounds. + # See the StrictMetricsEvaluator docs for more. + return ROWS_MIGHT_NOT_MATCH + + literals = {val for val in literals if lower <= val} + if len(literals) == 0: + return ROWS_MUST_MATCH + + if upper_bytes := self.upper_bounds.get(field_id): + upper = _from_byte_buffer(field.field_type, upper_bytes) + + literals = {val for val in literals if upper >= val} + + if len(literals) == 0: + return ROWS_MUST_MATCH + + return ROWS_MIGHT_NOT_MATCH + + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + return ROWS_MIGHT_NOT_MATCH + + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + return ROWS_MIGHT_NOT_MATCH + + def _can_contain_nulls(self, field_id: int) -> bool: + return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0 + + def _can_contain_nans(self, field_id: int) -> bool: + return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0 diff --git a/tests/expressions/test_evaluator.py b/tests/expressions/test_evaluator.py index 7d97a6d2d2..d805f122d9 100644 --- a/tests/expressions/test_evaluator.py +++ b/tests/expressions/test_evaluator.py @@ -39,7 +39,7 @@ Or, StartsWith, ) -from pyiceberg.expressions.visitors import _InclusiveMetricsEvaluator +from pyiceberg.expressions.visitors import _InclusiveMetricsEvaluator, _StrictMetricsEvaluator from pyiceberg.manifest import DataFile, FileFormat from pyiceberg.schema import Schema from pyiceberg.types import ( @@ -925,3 +925,530 @@ def test_string_not_starts_with( # should_read = _InclusiveMetricsEvaluator(schema_data_file, NotStartsWith("required", above_max)).eval(data_file_4) # assert should_read, "Should not read: range doesn't match" + + +@pytest.fixture +def strict_data_file_schema() -> Schema: + return Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "no_stats", IntegerType(), required=False), + NestedField(3, "required", StringType(), required=True), + NestedField(4, "all_nulls", StringType(), required=False), + NestedField(5, "some_nulls", StringType(), required=False), + NestedField(6, "no_nulls", StringType(), required=False), + NestedField(7, "always_5", IntegerType(), required=False), + NestedField(8, "all_nans", DoubleType(), required=False), + NestedField(9, "some_nans", FloatType(), required=False), + NestedField(10, "no_nans", FloatType(), required=False), + NestedField(11, "all_nulls_double", DoubleType(), required=False), + NestedField(12, "all_nans_v1_stats", FloatType(), required=False), + NestedField(13, "nan_and_null_only", DoubleType(), required=False), + NestedField(14, "no_nan_stats", DoubleType(), required=False), + ) + + +@pytest.fixture +def strict_data_file_1() -> DataFile: + return DataFile( + file_path="file_1.parquet", + file_format=FileFormat.PARQUET, + partition={}, + record_count=50, + file_size_in_bytes=3, + value_counts={ + 4: 50, + 5: 50, + 6: 50, + 8: 50, + 9: 50, + 10: 50, + 11: 50, + 12: 50, + 13: 50, + 14: 50, + }, + null_value_counts={4: 50, 5: 10, 6: 0, 11: 50, 12: 0, 13: 1}, + nan_value_counts={ + 8: 50, + 9: 10, + 10: 0, + }, + lower_bounds={ + 1: to_bytes(IntegerType(), INT_MIN_VALUE), + 7: to_bytes(IntegerType(), 5), + 12: to_bytes(FloatType(), float("nan")), + 13: to_bytes(DoubleType(), float("nan")), + }, + upper_bounds={ + 1: to_bytes(IntegerType(), INT_MAX_VALUE), + 7: to_bytes(IntegerType(), 5), + 12: to_bytes(FloatType(), float("nan")), + 14: to_bytes(DoubleType(), float("nan")), + }, + ) + + +@pytest.fixture +def strict_data_file_2() -> DataFile: + return DataFile( + file_path="file_2.parquet", + file_format=FileFormat.PARQUET, + partition={}, + record_count=50, + file_size_in_bytes=3, + value_counts={ + 4: 50, + 5: 50, + 6: 50, + 8: 50, + }, + null_value_counts={4: 50, 5: 10, 6: 0}, + nan_value_counts=None, + lower_bounds={ + 5: to_bytes(StringType(), "bbb"), + }, + upper_bounds={ + 5: to_bytes(StringType(), "eee"), + }, + ) + + +@pytest.fixture +def strict_data_file_3() -> DataFile: + return DataFile( + file_path="file_3.parquet", + file_format=FileFormat.PARQUET, + partition={}, + record_count=50, + file_size_in_bytes=3, + value_counts={ + 4: 50, + 5: 50, + 6: 50, + }, + null_value_counts={4: 50, 5: 10, 6: 0}, + nan_value_counts=None, + lower_bounds={ + 5: to_bytes(StringType(), "bbb"), + }, + upper_bounds={ + 5: to_bytes(StringType(), "eee"), + }, + ) + + +def test_strict_all_nulls(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNull("all_nulls")).eval(strict_data_file_1) + assert not should_read, "Should not match: no non-null value in all null column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNull("some_nulls")).eval(strict_data_file_1) + assert not should_read, "Should not match: column with some nulls contains a non-null value" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNull("no_nulls")).eval(strict_data_file_1) + assert should_read, "Should match: non-null column contains no null values" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("all_nulls", "a")).eval(strict_data_file_1) + assert should_read, "Should match: notEqual on all nulls column" + + +def test_strict_no_nulls(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNull("all_nulls")).eval(strict_data_file_1) + assert should_read, "Should match: all values are null" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNull("some_nulls")).eval(strict_data_file_1) + assert not should_read, "Should not match: not all values are null" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNull("no_nulls")).eval(strict_data_file_1) + assert not should_read, "Should not match: no values are null" + + +def test_strict_some_nulls(strict_data_file_schema: Schema, strict_data_file_2: DataFile, strict_data_file_3: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThan("some_nulls", "ggg")).eval(strict_data_file_2) + assert not should_read, "Should not match: lessThan on some nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThanOrEqual("some_nulls", "ggg")).eval(strict_data_file_2) + assert not should_read, "Should not match: lessThanOrEqual on some nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThan("some_nulls", "aaa")).eval(strict_data_file_2) + assert not should_read, "Should not match: greaterThan on some nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThanOrEqual("some_nulls", "bbb")).eval( + strict_data_file_2 + ) + assert not should_read, "Should not match: greaterThanOrEqual on some nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("some_nulls", "bbb")).eval(strict_data_file_3) + assert not should_read, "Should not match: equal on some nulls column" + + +def test_strict_is_nan(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("all_nans")).eval(strict_data_file_1) + assert should_read, "Should match: all values are nan" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("some_nans")).eval(strict_data_file_1) + assert not should_read, "Should not match: at least one non-nan value in some nan column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("no_nans")).eval(strict_data_file_1) + assert not should_read, "Should not match: at least one non-nan value in no nan column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("all_nulls_double")).eval(strict_data_file_1) + assert not should_read, "Should not match: at least one non-nan value in all null column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("no_nan_stats")).eval(strict_data_file_1) + assert not should_read, "Should not match: cannot determine without nan stats" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("all_nans_v1_stats")).eval(strict_data_file_1) + assert not should_read, "Should not match: cannot determine without nan stats" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNaN("nan_and_null_only")).eval(strict_data_file_1) + assert not should_read, "Should not match: null values are not nan" + + +def test_strict_not_nan(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("all_nans")).eval(strict_data_file_1) + # assert not should_read, "Should not match: all values are nan" + # + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("some_nans")).eval(strict_data_file_1) + # assert not should_read, "Should not match: at least one nan value in some nan column" + # + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("no_nans")).eval(strict_data_file_1) + # assert should_read, "Should match: no value is nan" + # + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("all_nulls_double")).eval(strict_data_file_1) + # assert should_read, "Should match: no nan value in all null column" + # + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("no_nan_stats")).eval(strict_data_file_1) + # assert not should_read, "Should not match: cannot determine without nan stats" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("all_nans_v1_stats")).eval(strict_data_file_1) + assert not should_read, "Should not match: all values are nan" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNaN("nan_and_null_only")).eval(strict_data_file_1) + assert not should_read, "Should not match: null values are not nan" + + +def test_strict_required_column(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotNull("required")).eval(strict_data_file_1) + assert should_read, "Should match: required columns are always non-null" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, IsNull("required")).eval(strict_data_file_1) + assert not should_read, "Should not match: required columns never contain null" + + +def test_strict_missing_column(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + with pytest.raises(ValueError) as exc_info: + _ = _StrictMetricsEvaluator(strict_data_file_schema, NotNull("missing")).eval(strict_data_file_1) + + assert str(exc_info.value) == "Could not find field with name missing, case_sensitive=True" + + +def test_strict_missing_stats(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + no_stats_schema = Schema( + NestedField(2, "no_stats", DoubleType(), required=False), + ) + + no_stats_file = DataFile( + file_path="file_1.parquet", + file_format=FileFormat.PARQUET, + partition={}, + record_count=50, + value_counts=None, + null_value_counts=None, + nan_value_counts=None, + lower_bounds=None, + upper_bounds=None, + ) + + expressions = [ + LessThan("no_stats", 5), + LessThanOrEqual("no_stats", 30), + EqualTo("no_stats", 70), + GreaterThan("no_stats", 78), + GreaterThanOrEqual("no_stats", 90), + NotEqualTo("no_stats", 101), + IsNull("no_stats"), + NotNull("no_stats"), + IsNaN("no_stats"), + NotNaN("no_stats"), + ] + + for expression in expressions: + should_read = _StrictMetricsEvaluator(no_stats_schema, expression).eval(no_stats_file) + assert not should_read, f"Should never match when stats are missing for expr: {expression}" + + +def test_strict_zero_record_file_stats(strict_data_file_schema: Schema) -> None: + zero_record_data_file = DataFile(file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition={}, record_count=0) + + expressions = [ + LessThan("no_stats", 5), + LessThanOrEqual("no_stats", 30), + EqualTo("no_stats", 70), + GreaterThan("no_stats", 78), + GreaterThanOrEqual("no_stats", 90), + NotEqualTo("no_stats", 101), + IsNull("no_stats"), + NotNull("no_stats"), + IsNaN("no_stats"), + NotNaN("no_stats"), + ] + + for expression in expressions: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, expression).eval(zero_record_data_file) + assert should_read, f"Should always match 0-record file: {expression}" + + +def test_strict_not(schema_data_file: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(schema_data_file, Not(LessThan("id", INT_MIN_VALUE - 25))).eval(strict_data_file_1) + assert should_read, "Should not match: not(false)" + + should_read = _StrictMetricsEvaluator(schema_data_file, Not(GreaterThan("id", INT_MIN_VALUE - 25))).eval(strict_data_file_1) + assert not should_read, "Should match: not(true)" + + +def test_strict_and(schema_data_file: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator( + schema_data_file, And(GreaterThan("id", INT_MIN_VALUE - 25), LessThanOrEqual("id", INT_MIN_VALUE)) + ).eval(strict_data_file_1) + assert not should_read, "Should not match: range may not overlap data" + + should_read = _StrictMetricsEvaluator( + schema_data_file, And(LessThan("id", INT_MIN_VALUE - 25), GreaterThanOrEqual("id", INT_MIN_VALUE - 30)) + ).eval(strict_data_file_1) + assert not should_read, "Should not match: range does not overlap data" + + should_read = _StrictMetricsEvaluator( + schema_data_file, And(LessThan("id", INT_MAX_VALUE + 6), GreaterThanOrEqual("id", INT_MIN_VALUE - 30)) + ).eval(strict_data_file_1) + assert should_read, "Should match: range includes all data" + + +def test_strict_or(schema_data_file: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator( + schema_data_file, Or(LessThan("id", INT_MIN_VALUE - 25), GreaterThanOrEqual("id", INT_MAX_VALUE + 1)) + ).eval(strict_data_file_1) + assert not should_read, "Should not match: no matching values" + + should_read = _StrictMetricsEvaluator( + schema_data_file, Or(LessThan("id", INT_MIN_VALUE - 25), GreaterThanOrEqual("id", INT_MAX_VALUE - 19)) + ).eval(strict_data_file_1) + assert not should_read, "Should not match: some values do not match" + + should_read = _StrictMetricsEvaluator( + schema_data_file, Or(LessThan("id", INT_MIN_VALUE - 25), GreaterThanOrEqual("id", INT_MIN_VALUE)) + ).eval(strict_data_file_1) + assert should_read, "Should match: all values match >= 30" + + +def test_strict_integer_lt(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThan("id", INT_MIN_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: always false" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThan("id", INT_MIN_VALUE + 1)).eval(strict_data_file_1) + assert not should_read, "Should not match: 32 and greater not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThan("id", INT_MAX_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: 79 not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThan("id", INT_MAX_VALUE + 1)).eval(strict_data_file_1) + assert should_read, "Should match: all values in range" + + +def test_strict_integer_lt_eq(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThanOrEqual("id", INT_MIN_VALUE - 1)).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: always false" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThanOrEqual("id", INT_MIN_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: 31 and greater not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThanOrEqual("id", INT_MAX_VALUE)).eval(strict_data_file_1) + assert should_read, "Should match: all values in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, LessThanOrEqual("id", INT_MAX_VALUE + 1)).eval( + strict_data_file_1 + ) + assert should_read, "Should match: all values in range" + + +def test_strict_integer_gt(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThan("id", INT_MAX_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: always false" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThan("id", INT_MAX_VALUE - 1)).eval(strict_data_file_1) + assert not should_read, "Should not match: 77 and less not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThan("id", INT_MIN_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: 30 not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThan("id", INT_MIN_VALUE - 1)).eval(strict_data_file_1) + assert should_read, "Should match: all values in range" + + +def test_strict_integer_gt_eq(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThanOrEqual("id", INT_MAX_VALUE + 1)).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: no values in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThanOrEqual("id", INT_MAX_VALUE)).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: 78 and lower are not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThanOrEqual("id", INT_MIN_VALUE + 1)).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: 30 not in range" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, GreaterThanOrEqual("id", INT_MIN_VALUE)).eval( + strict_data_file_1 + ) + assert should_read, "Should match: all values in range" + + +def test_strict_integer_eq(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("id", INT_MIN_VALUE - 25)).eval(strict_data_file_1) + assert not should_read, "Should not match: all values != 5" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("id", INT_MIN_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: some values != 30" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("id", INT_MAX_VALUE - 4)).eval(strict_data_file_1) + assert not should_read, "Should not match: some values != 75" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("id", INT_MAX_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: some values != 79" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("id", INT_MAX_VALUE + 1)).eval(strict_data_file_1) + assert not should_read, "Should not match: some values != 80" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, EqualTo("always_5", INT_MIN_VALUE - 25)).eval( + strict_data_file_1 + ) + assert should_read, "Should match: all values == 5" + + +def test_strict_integer_not_eq(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MIN_VALUE - 25)).eval(strict_data_file_1) + assert should_read, "Should match: no values == 5" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MIN_VALUE - 1)).eval(strict_data_file_1) + assert should_read, "Should match: no values == 39" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MIN_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 30" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MAX_VALUE - 4)).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 75" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MAX_VALUE)).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 79" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MAX_VALUE + 1)).eval(strict_data_file_1) + assert should_read, "Should match: no values == 80" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotEqualTo("id", INT_MAX_VALUE + 6)).eval(strict_data_file_1) + assert should_read, "Should read: no values == 85" + + +def test_strict_integer_not_eq_rewritten(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MIN_VALUE - 25))).eval( + strict_data_file_1 + ) + assert should_read, "Should match: no values == 5" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MIN_VALUE - 1))).eval(strict_data_file_1) + assert should_read, "Should match: no values == 39" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MIN_VALUE))).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 30" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MAX_VALUE - 4))).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 75" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MAX_VALUE))).eval(strict_data_file_1) + assert not should_read, "Should not match: some value may be == 79" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MAX_VALUE + 1))).eval(strict_data_file_1) + assert should_read, "Should match: no values == 80" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, Not(EqualTo("id", INT_MAX_VALUE + 6))).eval(strict_data_file_1) + assert should_read, "Should read: no values == 85" + + +def test_strict_integer_in(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("id", {INT_MIN_VALUE - 25, INT_MIN_VALUE - 24})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: all values != 5 and != 6" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("id", {INT_MIN_VALUE - 1, INT_MIN_VALUE})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some values != 30 and != 31" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("id", {INT_MAX_VALUE - 4, INT_MAX_VALUE - 3})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some values != 75 and != 76" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("id", {INT_MAX_VALUE, INT_MAX_VALUE + 1})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some values != 78 and != 79" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("id", {INT_MAX_VALUE + 1, INT_MAX_VALUE + 2})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some values != 80 and != 81)" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("always_5", {5, 6})).eval(strict_data_file_1) + assert should_read, "Should match: all values == 5" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("all_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert not should_read, "Should not match: in on all nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("some_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert not should_read, "Should not match: in on some nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, In("no_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert not should_read, "Should not match: no_nulls field does not have bounds" + + +def test_strict_integer_not_in(strict_data_file_schema: Schema, strict_data_file_1: DataFile) -> None: + # should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("id", {INT_MIN_VALUE - 25, INT_MIN_VALUE - 24})).eval(strict_data_file_1) + # assert should_read, "Should match: all values != 5 and != 6" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("id", {INT_MIN_VALUE - 1, INT_MIN_VALUE})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some values may be == 30" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("id", {INT_MAX_VALUE - 4, INT_MAX_VALUE - 3})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some value may be == 75 or == 76" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("id", {INT_MAX_VALUE, INT_MAX_VALUE + 1})).eval( + strict_data_file_1 + ) + assert not should_read, "Should not match: some value may be == 79" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("id", {INT_MAX_VALUE + 1, INT_MAX_VALUE + 2})).eval( + strict_data_file_1 + ) + assert should_read, "Should match: no values == 80 or == 81" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("always_5", {5, 6})).eval(strict_data_file_1) + assert not should_read, "Should not match: all values == 5" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("all_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert should_read, "Should match: notIn on all nulls column" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("some_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert should_read, "Should match: notIn on some nulls column, 'bbb' > 'abc' and 'bbb' < 'def'" + + should_read = _StrictMetricsEvaluator(strict_data_file_schema, NotIn("no_nulls", {"abc", "def"})).eval(strict_data_file_1) + assert not should_read, "Should not match: no_nulls field does not have bounds"