From 964f7240e5c05c83c16865bd0966d505d0f06378 Mon Sep 17 00:00:00 2001 From: Drew Gallardo Date: Sat, 20 Sep 2025 11:09:10 -0700 Subject: [PATCH] Parquet, Data, Spark: Fix variant type filtering in ParquetMetricsRowGroupFilter (#14081) (cherry picked from commit fb63af014c39f687f3016a1b3c8bde4a494e9a4a) --- .../data/TestMetricsRowGroupFilter.java | 147 +++++++++++++----- .../parquet/ParquetMetricsRowGroupFilter.java | 9 +- .../iceberg/spark/SparkTestHelperBase.java | 8 + .../iceberg/spark/sql/TestFilterPushDown.java | 89 ++++++++++- 4 files changed, 212 insertions(+), 41 deletions(-) diff --git a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java index 384dcacd10cf..e12015d5eb73 100644 --- a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java +++ b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java @@ -18,7 +18,6 @@ */ package org.apache.iceberg.data; -import static org.apache.iceberg.avro.AvroSchemaUtil.convert; import static org.apache.iceberg.expressions.Expressions.and; import static org.apache.iceberg.expressions.Expressions.equal; import static org.apache.iceberg.expressions.Expressions.greaterThan; @@ -49,8 +48,6 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; -import org.apache.avro.generic.GenericData.Record; -import org.apache.avro.generic.GenericRecordBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.iceberg.FileFormat; @@ -59,9 +56,9 @@ import org.apache.iceberg.ParameterizedTestExtension; import org.apache.iceberg.Parameters; import org.apache.iceberg.Schema; -import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.data.orc.GenericOrcReader; import org.apache.iceberg.data.orc.GenericOrcWriter; +import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.io.CloseableIterable; @@ -78,6 +75,10 @@ import org.apache.iceberg.types.Types.FloatType; import org.apache.iceberg.types.Types.IntegerType; import org.apache.iceberg.types.Types.StringType; +import org.apache.iceberg.variants.ShreddedObject; +import org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.Variants; import org.apache.orc.OrcFile; import org.apache.orc.Reader; import org.apache.parquet.hadoop.ParquetFileReader; @@ -138,6 +139,11 @@ public static List parameters() { optional(16, "_no_nans", Types.DoubleType.get()), optional(17, "_some_double_nans", Types.DoubleType.get())); + private static final Schema VARIANT_SCHEMA = + new Schema( + required(1, "id", IntegerType.get()), + optional(2, "variant_field", Types.VariantType.get())); + private static final String TOO_LONG_FOR_STATS_PARQUET; static { @@ -220,40 +226,32 @@ public void createOrcInputFile() throws IOException { } private void createParquetInputFile() throws IOException { - File parquetFile = new File(tempDir, "junit" + System.nanoTime()); - - // build struct field schema - org.apache.avro.Schema structSchema = AvroSchemaUtil.convert(UNDERSCORE_STRUCT_FIELD_TYPE); - - OutputFile outFile = Files.localOutput(parquetFile); - try (FileAppender appender = Parquet.write(outFile).schema(FILE_SCHEMA).build()) { - GenericRecordBuilder builder = new GenericRecordBuilder(convert(FILE_SCHEMA, "table")); - // create 50 records - 
for (int i = 0; i < INT_MAX_VALUE - INT_MIN_VALUE + 1; i += 1) { - builder.set("_id", INT_MIN_VALUE + i); // min=30, max=79, num-nulls=0 - builder.set( - "_no_stats_parquet", - TOO_LONG_FOR_STATS_PARQUET); // value longer than 4k will produce no stats - // in Parquet - builder.set("_required", "req"); // required, always non-null - builder.set("_all_nulls", null); // never non-null - builder.set("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values - builder.set("_no_nulls", ""); // optional, but always non-null - builder.set("_all_nans", Double.NaN); // never non-nan - builder.set("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values - builder.set( - "_some_double_nans", (i % 10 == 0) ? Double.NaN : 2D); // includes some nan values - builder.set("_no_nans", 3D); // optional, but always non-nan - builder.set("_str", i + "str" + i); - - Record structNotNull = new Record(structSchema); - structNotNull.put("_int_field", INT_MIN_VALUE + i); - builder.set("_struct_not_null", structNotNull); // struct with int - - appender.add(builder.build()); - } + List records = Lists.newArrayList(); + + for (int i = 0; i < INT_MAX_VALUE - INT_MIN_VALUE + 1; i += 1) { + GenericRecord builder = GenericRecord.create(FILE_SCHEMA); + builder.setField("_id", INT_MIN_VALUE + i); // min=30, max=79, num-nulls=0 + builder.setField( + "_no_stats_parquet", + TOO_LONG_FOR_STATS_PARQUET); // value longer than 4k will produce no stats + // in Parquet + builder.setField("_required", "req"); // required, always non-null + builder.setField("_all_nulls", null); // never non-null + builder.setField("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values + builder.setField("_no_nulls", ""); // optional, but always non-null + builder.setField("_all_nans", Double.NaN); // never non-nan + builder.setField("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values + builder.setField( + "_some_double_nans", (i % 10 == 0) ? 
Double.NaN : 2D); // includes some nan values + builder.setField("_no_nans", 3D); // optional, but always non-nan + builder.setField("_str", i + "str" + i); + GenericRecord structNotNull = GenericRecord.create(UNDERSCORE_STRUCT_FIELD_TYPE); + structNotNull.setField("_int_field", INT_MIN_VALUE + i); + builder.setField("_struct_not_null", structNotNull); // struct with int + records.add(builder); } + File parquetFile = writeParquetFile("junit", FILE_SCHEMA, records); InputFile inFile = Files.localInput(parquetFile); try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) { assertThat(reader.getRowGroups()).as("Should create only one row group").hasSize(1); @@ -264,6 +262,24 @@ private void createParquetInputFile() throws IOException { parquetFile.deleteOnExit(); } + private File writeParquetFile(String fileName, Schema schema, List records) + throws IOException { + File parquetFile = new File(tempDir, fileName + System.nanoTime()); + + OutputFile outFile = Files.localOutput(parquetFile); + try (FileAppender appender = + Parquet.write(outFile) + .schema(schema) + .createWriterFunc(GenericParquetWriter::create) + .build()) { + for (GenericRecord record : records) { + appender.add(record); + } + } + parquetFile.deleteOnExit(); + return parquetFile; + } + @TestTemplate public void testAllNulls() { boolean shouldRead; @@ -988,6 +1004,65 @@ public void testTransformFilter() { .isTrue(); } + @TestTemplate + public void testVariantFieldMixedValuesNotNull() throws IOException { + assumeThat(format).isEqualTo(FileFormat.PARQUET); + + List records = Lists.newArrayList(); + for (int i = 0; i < 10; i++) { + GenericRecord record = GenericRecord.create(VARIANT_SCHEMA); + record.setField("id", i); + if (i % 2 == 0) { + VariantMetadata metadata = Variants.metadata("field"); + ShreddedObject obj = Variants.object(metadata); + obj.put("field", Variants.of("value" + i)); + Variant variant = Variant.of(metadata, obj); + record.setField("variant_field", variant); + } + records.add(record); + } + + File parquetFile = writeParquetFile("test-variant", VARIANT_SCHEMA, records); + InputFile inFile = Files.localInput(parquetFile); + try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) { + BlockMetaData blockMetaData = reader.getRowGroups().get(0); + MessageType fileSchema = reader.getFileMetaData().getSchema(); + ParquetMetricsRowGroupFilter rowGroupFilter = + new ParquetMetricsRowGroupFilter(VARIANT_SCHEMA, notNull("variant_field"), true); + + assertThat(rowGroupFilter.shouldRead(fileSchema, blockMetaData)) + .as("Should read: variant notNull filters must be evaluated post scan") + .isTrue(); + } + } + + @TestTemplate + public void testVariantFieldAllNullsNotNull() throws IOException { + assumeThat(format).isEqualTo(FileFormat.PARQUET); + + List records = Lists.newArrayListWithExpectedSize(10); + for (int i = 0; i < 10; i++) { + GenericRecord record = GenericRecord.create(VARIANT_SCHEMA); + record.setField("id", i); + record.setField("variant_field", null); + records.add(record); + } + + File parquetFile = writeParquetFile("test-variant-nulls", VARIANT_SCHEMA, records); + InputFile inFile = Files.localInput(parquetFile); + + try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) { + BlockMetaData blockMetaData = reader.getRowGroups().get(0); + MessageType fileSchema = reader.getFileMetaData().getSchema(); + ParquetMetricsRowGroupFilter rowGroupFilter = + new ParquetMetricsRowGroupFilter(VARIANT_SCHEMA, notNull("variant_field"), true); + 
+      assertThat(rowGroupFilter.shouldRead(fileSchema, blockMetaData))
+          .as("Should read: variant notNull filters must be evaluated post scan even for all nulls")
+          .isTrue();
+    }
+  }
+
   private boolean shouldRead(Expression expression) {
     return shouldRead(expression, true);
   }
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java
index 1ad346d39ab7..598e5dd23548 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetMetricsRowGroupFilter.java
@@ -155,10 +155,11 @@ public Boolean notNull(BoundReference ref) {
       // if the column has no non-null values, the expression cannot match
       int id = ref.fieldId();
 
-      // When filtering nested types notNull() is implicit filter passed even though complex
-      // filters aren't pushed down in Parquet. Leave all nested column type filters to be
-      // evaluated post scan.
-      if (schema.findType(id) instanceof Type.NestedType) {
+      // When filtering nested types or variant types, notNull() is an implicit filter passed
+      // even though complex filters aren't pushed down in Parquet. Leave these type filters
+      // to be evaluated post scan.
+      Type type = schema.findType(id);
+      if (type instanceof Type.NestedType || type.isVariantType()) {
         return ROWS_MIGHT_MATCH;
       }
 
diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java
index 9fc71125a92e..2754e891a481 100644
--- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java
+++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java
@@ -24,6 +24,7 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.spark.sql.Row;
+import org.apache.spark.unsafe.types.VariantVal;
 
 public class SparkTestHelperBase {
   protected static final Object ANY = new Object();
@@ -79,6 +80,13 @@ protected void assertEquals(String context, Object[] expectedRow, Object[] actua
         } else {
           assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
         }
+      } else if (expectedValue instanceof VariantVal && actualValue instanceof VariantVal) {
+        // Spark VariantVal comparison is based on raw byte[] comparison, which can fail
+        // if Spark uses trailing null bytes, so we compare their JSON representations instead.
+ assertThat(actualValue) + .asString() + .as("%s contents should match (VariantVal JSON)", context) + .isEqualTo((expectedValue).toString()); } else if (expectedValue != ANY) { assertThat(actualValue).as("%s contents should match", context).isEqualTo(expectedValue); } diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java index 9d2ce2b388a2..a984c4c826d2 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestFilterPushDown.java @@ -23,6 +23,8 @@ import static org.assertj.core.api.Assertions.assertThat; import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.sql.Timestamp; import java.time.Instant; import java.util.List; @@ -35,7 +37,12 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.spark.SparkCatalogConfig; import org.apache.iceberg.spark.TestBaseWithCatalog; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.variants.ShreddedObject; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.Variants; import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.unsafe.types.VariantVal; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; @@ -578,6 +585,68 @@ public void testFilterPushdownWithSpecialFloatingPointPartitionValues() { ImmutableList.of(row(4, Double.NEGATIVE_INFINITY))); } + @TestTemplate + public void testVariantExtractFiltering() { + sql( + "CREATE TABLE %s (id BIGINT, data VARIANT) USING iceberg TBLPROPERTIES" + + "('format-version'='3')", + tableName); + configurePlanningMode(planningMode); + + sql( + "INSERT INTO %s VALUES " + + "(1, parse_json('{\"field\": \"foo\", \"num\": 25}')), " + + "(2, parse_json('{\"field\": \"bar\", \"num\": 30}')), " + + "(3, parse_json('{\"field\": \"baz\", \"num\": 35}')), " + + "(4, null)", + tableName); + + withDefaultTimeZone( + "UTC", + () -> { + checkFilters( + "try_variant_get(data, '$.num', 'int') IS NOT NULL", + "isnotnull(data) AND isnotnull(try_variant_get(data, $.num, IntegerType, false, Some(UTC)))", + "data IS NOT NULL", + ImmutableList.of( + row(1L, toSparkVariantRow("foo", 25)), + row(2L, toSparkVariantRow("bar", 30)), + row(3L, toSparkVariantRow("baz", 35)))); + + checkFilters( + "try_variant_get(data, '$.num', 'int') IS NULL", + "isnull(try_variant_get(data, $.num, IntegerType, false, Some(UTC)))", + "", + ImmutableList.of(row(4L, null))); + + checkFilters( + "try_variant_get(data, '$.num', 'int') > 30", + "isnotnull(data) AND (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) > 30)", + "data IS NOT NULL", + ImmutableList.of(row(3L, toSparkVariantRow("baz", 35)))); + + checkFilters( + "try_variant_get(data, '$.num', 'int') = 30", + "isnotnull(data) AND (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) = 30)", + "data IS NOT NULL", + ImmutableList.of(row(2L, toSparkVariantRow("bar", 30)))); + + checkFilters( + "try_variant_get(data, '$.num', 'int') IN (25, 35)", + "try_variant_get(data, $.num, IntegerType, false, Some(UTC)) IN (25,35)", + "", + ImmutableList.of( + row(1L, toSparkVariantRow("foo", 25)), row(3L, toSparkVariantRow("baz", 35)))); + + checkFilters( + "try_variant_get(data, '$.num', 'int') != 25", + "isnotnull(data) 
AND NOT (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) = 25)", + "data IS NOT NULL", + ImmutableList.of( + row(2L, toSparkVariantRow("bar", 30)), row(3L, toSparkVariantRow("baz", 35)))); + }); + } + private void checkOnlyIcebergFilters( String predicate, String icebergFilters, List expectedRows) { @@ -600,7 +669,7 @@ private void checkFilters( if (sparkFilter != null) { assertThat(planAsString) .as("Post scan filter should match") - .contains("Filter (" + sparkFilter + ")"); + .containsAnyOf("Filter (" + sparkFilter + ")", "Filter " + sparkFilter); } else { assertThat(planAsString).as("Should be no post scan filter").doesNotContain("Filter ("); } @@ -613,4 +682,22 @@ private void checkFilters( private Timestamp timestamp(String timestampAsString) { return Timestamp.from(Instant.parse(timestampAsString)); } + + private VariantVal toSparkVariantRow(String field, int num) { + VariantMetadata metadata = Variants.metadata("field", "num"); + + ShreddedObject obj = Variants.object(metadata); + obj.put("field", Variants.of(field)); + obj.put("num", Variants.of(num)); + + ByteBuffer metadataBuffer = + ByteBuffer.allocate(metadata.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN); + metadata.writeTo(metadataBuffer, 0); + + ByteBuffer valueBuffer = ByteBuffer.allocate(obj.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN); + obj.writeTo(valueBuffer, 0); + + return new VariantVal( + ByteBuffers.toByteArray(valueBuffer), ByteBuffers.toByteArray(metadataBuffer)); + } }
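
For context, the behavior this patch establishes can be exercised directly. The sketch below is illustrative and not part of the commit; it mirrors the new TestMetricsRowGroupFilter tests and uses only the APIs that appear in the diff (the ParquetMetricsRowGroupFilter(Schema, Expression, boolean) constructor, shouldRead(MessageType, BlockMetaData), and Expressions.notNull). The `icebergSchema` and `reader` arguments are assumed to be supplied by the caller, and the column name "variant_field" matches the test schema above.

import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.parquet.ParquetMetricsRowGroupFilter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.schema.MessageType;

class VariantNotNullPruningSketch {
  // Decides whether the first row group must be scanned for notNull("variant_field").
  // With this fix, the metrics filter no longer consults Parquet column statistics for
  // variant (or nested) columns: the call returns true and the notNull predicate is
  // evaluated post scan instead of being used to drop the row group.
  static boolean mustScan(Schema icebergSchema, ParquetFileReader reader) {
    MessageType fileSchema = reader.getFileMetaData().getSchema();
    BlockMetaData firstRowGroup = reader.getRowGroups().get(0);
    ParquetMetricsRowGroupFilter filter =
        new ParquetMetricsRowGroupFilter(
            icebergSchema, Expressions.notNull("variant_field"), true /* caseSensitive */);
    return filter.shouldRead(fileSchema, firstRowGroup);
  }
}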