@@ -18,7 +18,6 @@
*/
package org.apache.iceberg.data;

import static org.apache.iceberg.avro.AvroSchemaUtil.convert;
import static org.apache.iceberg.expressions.Expressions.and;
import static org.apache.iceberg.expressions.Expressions.equal;
import static org.apache.iceberg.expressions.Expressions.greaterThan;
@@ -49,8 +48,6 @@
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import org.apache.avro.generic.GenericData.Record;
import org.apache.avro.generic.GenericRecordBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.FileFormat;
@@ -59,9 +56,9 @@
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Parameters;
import org.apache.iceberg.Schema;
import org.apache.iceberg.avro.AvroSchemaUtil;
import org.apache.iceberg.data.orc.GenericOrcReader;
import org.apache.iceberg.data.orc.GenericOrcWriter;
import org.apache.iceberg.data.parquet.GenericParquetWriter;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.io.CloseableIterable;
@@ -78,6 +75,10 @@
import org.apache.iceberg.types.Types.FloatType;
import org.apache.iceberg.types.Types.IntegerType;
import org.apache.iceberg.types.Types.StringType;
import org.apache.iceberg.variants.ShreddedObject;
import org.apache.iceberg.variants.Variant;
import org.apache.iceberg.variants.VariantMetadata;
import org.apache.iceberg.variants.Variants;
import org.apache.orc.OrcFile;
import org.apache.orc.Reader;
import org.apache.parquet.hadoop.ParquetFileReader;
@@ -138,6 +139,11 @@ public static List<Object> parameters() {
optional(16, "_no_nans", Types.DoubleType.get()),
optional(17, "_some_double_nans", Types.DoubleType.get()));

private static final Schema VARIANT_SCHEMA =
new Schema(
required(1, "id", IntegerType.get()),
optional(2, "variant_field", Types.VariantType.get()));

private static final String TOO_LONG_FOR_STATS_PARQUET;

static {
@@ -220,40 +226,32 @@ public void createOrcInputFile() throws IOException {
}

private void createParquetInputFile() throws IOException {
File parquetFile = new File(tempDir, "junit" + System.nanoTime());

// build struct field schema
org.apache.avro.Schema structSchema = AvroSchemaUtil.convert(UNDERSCORE_STRUCT_FIELD_TYPE);

OutputFile outFile = Files.localOutput(parquetFile);
try (FileAppender<Record> appender = Parquet.write(outFile).schema(FILE_SCHEMA).build()) {
GenericRecordBuilder builder = new GenericRecordBuilder(convert(FILE_SCHEMA, "table"));
// create 50 records
for (int i = 0; i < INT_MAX_VALUE - INT_MIN_VALUE + 1; i += 1) {
builder.set("_id", INT_MIN_VALUE + i); // min=30, max=79, num-nulls=0
builder.set(
"_no_stats_parquet",
TOO_LONG_FOR_STATS_PARQUET); // value longer than 4k will produce no stats
// in Parquet
builder.set("_required", "req"); // required, always non-null
builder.set("_all_nulls", null); // never non-null
builder.set("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values
builder.set("_no_nulls", ""); // optional, but always non-null
builder.set("_all_nans", Double.NaN); // never non-nan
builder.set("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values
builder.set(
"_some_double_nans", (i % 10 == 0) ? Double.NaN : 2D); // includes some nan values
builder.set("_no_nans", 3D); // optional, but always non-nan
builder.set("_str", i + "str" + i);

Record structNotNull = new Record(structSchema);
structNotNull.put("_int_field", INT_MIN_VALUE + i);
builder.set("_struct_not_null", structNotNull); // struct with int

appender.add(builder.build());
}
List<GenericRecord> records = Lists.newArrayList();

for (int i = 0; i < INT_MAX_VALUE - INT_MIN_VALUE + 1; i += 1) {
GenericRecord builder = GenericRecord.create(FILE_SCHEMA);
builder.setField("_id", INT_MIN_VALUE + i); // min=30, max=79, num-nulls=0
builder.setField(
"_no_stats_parquet",
TOO_LONG_FOR_STATS_PARQUET); // value longer than 4k will produce no stats
// in Parquet
builder.setField("_required", "req"); // required, always non-null
builder.setField("_all_nulls", null); // never non-null
builder.setField("_some_nulls", (i % 10 == 0) ? null : "some"); // includes some null values
builder.setField("_no_nulls", ""); // optional, but always non-null
builder.setField("_all_nans", Double.NaN); // never non-nan
builder.setField("_some_nans", (i % 10 == 0) ? Float.NaN : 2F); // includes some nan values
builder.setField(
"_some_double_nans", (i % 10 == 0) ? Double.NaN : 2D); // includes some nan values
builder.setField("_no_nans", 3D); // optional, but always non-nan
builder.setField("_str", i + "str" + i);
GenericRecord structNotNull = GenericRecord.create(UNDERSCORE_STRUCT_FIELD_TYPE);
structNotNull.setField("_int_field", INT_MIN_VALUE + i);
builder.setField("_struct_not_null", structNotNull); // struct with int
records.add(builder);
}

File parquetFile = writeParquetFile("junit", FILE_SCHEMA, records);
InputFile inFile = Files.localInput(parquetFile);
try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) {
assertThat(reader.getRowGroups()).as("Should create only one row group").hasSize(1);
@@ -264,6 +262,24 @@ private void createParquetInputFile() throws IOException {
parquetFile.deleteOnExit();
}

private File writeParquetFile(String fileName, Schema schema, List<GenericRecord> records)
throws IOException {
File parquetFile = new File(tempDir, fileName + System.nanoTime());

OutputFile outFile = Files.localOutput(parquetFile);
try (FileAppender<GenericRecord> appender =
Parquet.write(outFile)
.schema(schema)
.createWriterFunc(GenericParquetWriter::create)
.build()) {
for (GenericRecord record : records) {
appender.add(record);
}
}
parquetFile.deleteOnExit();
return parquetFile;
}

@TestTemplate
public void testAllNulls() {
boolean shouldRead;
@@ -988,6 +1004,65 @@ public void testTransformFilter() {
.isTrue();
}

@TestTemplate
public void testVariantFieldMixedValuesNotNull() throws IOException {
assumeThat(format).isEqualTo(FileFormat.PARQUET);

List<GenericRecord> records = Lists.newArrayList();
for (int i = 0; i < 10; i++) {
GenericRecord record = GenericRecord.create(VARIANT_SCHEMA);
record.setField("id", i);
if (i % 2 == 0) {
VariantMetadata metadata = Variants.metadata("field");
ShreddedObject obj = Variants.object(metadata);
obj.put("field", Variants.of("value" + i));
Variant variant = Variant.of(metadata, obj);
record.setField("variant_field", variant);
}
records.add(record);
}

File parquetFile = writeParquetFile("test-variant", VARIANT_SCHEMA, records);
InputFile inFile = Files.localInput(parquetFile);
try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) {
BlockMetaData blockMetaData = reader.getRowGroups().get(0);
MessageType fileSchema = reader.getFileMetaData().getSchema();
ParquetMetricsRowGroupFilter rowGroupFilter =
new ParquetMetricsRowGroupFilter(VARIANT_SCHEMA, notNull("variant_field"), true);

assertThat(rowGroupFilter.shouldRead(fileSchema, blockMetaData))
.as("Should read: variant notNull filters must be evaluated post scan")
.isTrue();
}
}

@TestTemplate
public void testVariantFieldAllNullsNotNull() throws IOException {
assumeThat(format).isEqualTo(FileFormat.PARQUET);

List<GenericRecord> records = Lists.newArrayListWithExpectedSize(10);
for (int i = 0; i < 10; i++) {
GenericRecord record = GenericRecord.create(VARIANT_SCHEMA);
record.setField("id", i);
record.setField("variant_field", null);
records.add(record);
}

File parquetFile = writeParquetFile("test-variant-nulls", VARIANT_SCHEMA, records);
InputFile inFile = Files.localInput(parquetFile);

try (ParquetFileReader reader = ParquetFileReader.open(parquetInputFile(inFile))) {
BlockMetaData blockMetaData = reader.getRowGroups().get(0);
MessageType fileSchema = reader.getFileMetaData().getSchema();
ParquetMetricsRowGroupFilter rowGroupFilter =
new ParquetMetricsRowGroupFilter(VARIANT_SCHEMA, notNull("variant_field"), true);

assertThat(rowGroupFilter.shouldRead(fileSchema, blockMetaData))
.as("Should read: variant notNull filters must be evaluated post scan even for all nulls")
.isTrue();
}
}

private boolean shouldRead(Expression expression) {
return shouldRead(expression, true);
}
@@ -155,10 +155,11 @@ public <T> Boolean notNull(BoundReference<T> ref) {
// if the column has no non-null values, the expression cannot match
int id = ref.fieldId();

// When filtering nested types notNull() is implicit filter passed even though complex
// filters aren't pushed down in Parquet. Leave all nested column type filters to be
// evaluated post scan.
if (schema.findType(id) instanceof Type.NestedType) {
// When filtering nested or variant types, notNull() is passed as an implicit filter even
// though complex filters aren't pushed down to Parquet. Leave these filters to be
// evaluated post scan.
Type type = schema.findType(id);
if (type instanceof Type.NestedType || type.isVariantType()) {
return ROWS_MIGHT_MATCH;
}

@@ -24,6 +24,7 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.sql.Row;
import org.apache.spark.unsafe.types.VariantVal;

public class SparkTestHelperBase {
protected static final Object ANY = new Object();
@@ -79,6 +80,13 @@ protected void assertEquals(String context, Object[] expectedRow, Object[] actua
} else {
assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
}
} else if (expectedValue instanceof VariantVal && actualValue instanceof VariantVal) {
// Spark VariantVal comparison is based on raw byte[] equality, which can fail if Spark
// includes trailing null bytes, so we compare their JSON representations instead.
assertThat(actualValue)
.asString()
.as("%s contents should match (VariantVal JSON)", context)
.isEqualTo(expectedValue.toString());
} else if (expectedValue != ANY) {
assertThat(actualValue).as("%s contents should match", context).isEqualTo(expectedValue);
}
@@ -23,6 +23,8 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.List;
@@ -35,7 +37,12 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.TestBaseWithCatalog;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.variants.ShreddedObject;
import org.apache.iceberg.variants.VariantMetadata;
import org.apache.iceberg.variants.Variants;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.unsafe.types.VariantVal;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -578,6 +585,68 @@ public void testFilterPushdownWithSpecialFloatingPointPartitionValues() {
ImmutableList.of(row(4, Double.NEGATIVE_INFINITY)));
}

@TestTemplate
public void testVariantExtractFiltering() {
sql(
"CREATE TABLE %s (id BIGINT, data VARIANT) USING iceberg TBLPROPERTIES"
+ "('format-version'='3')",
tableName);
configurePlanningMode(planningMode);

sql(
"INSERT INTO %s VALUES "
+ "(1, parse_json('{\"field\": \"foo\", \"num\": 25}')), "
+ "(2, parse_json('{\"field\": \"bar\", \"num\": 30}')), "
+ "(3, parse_json('{\"field\": \"baz\", \"num\": 35}')), "
+ "(4, null)",
tableName);

withDefaultTimeZone(
"UTC",
() -> {
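// each checkFilters call verifies, for a predicate over try_variant_get: the expected
// Spark post-scan filter, the filter pushed down to the Iceberg scan, and the rows returned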
checkFilters(
"try_variant_get(data, '$.num', 'int') IS NOT NULL",
"isnotnull(data) AND isnotnull(try_variant_get(data, $.num, IntegerType, false, Some(UTC)))",
"data IS NOT NULL",
ImmutableList.of(
row(1L, toSparkVariantRow("foo", 25)),
row(2L, toSparkVariantRow("bar", 30)),
row(3L, toSparkVariantRow("baz", 35))));

checkFilters(
"try_variant_get(data, '$.num', 'int') IS NULL",
"isnull(try_variant_get(data, $.num, IntegerType, false, Some(UTC)))",
"",
ImmutableList.of(row(4L, null)));

checkFilters(
"try_variant_get(data, '$.num', 'int') > 30",
"isnotnull(data) AND (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) > 30)",
"data IS NOT NULL",
ImmutableList.of(row(3L, toSparkVariantRow("baz", 35))));

checkFilters(
"try_variant_get(data, '$.num', 'int') = 30",
"isnotnull(data) AND (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) = 30)",
"data IS NOT NULL",
ImmutableList.of(row(2L, toSparkVariantRow("bar", 30))));

checkFilters(
"try_variant_get(data, '$.num', 'int') IN (25, 35)",
"try_variant_get(data, $.num, IntegerType, false, Some(UTC)) IN (25,35)",
"",
ImmutableList.of(
row(1L, toSparkVariantRow("foo", 25)), row(3L, toSparkVariantRow("baz", 35))));

checkFilters(
"try_variant_get(data, '$.num', 'int') != 25",
"isnotnull(data) AND NOT (try_variant_get(data, $.num, IntegerType, false, Some(UTC)) = 25)",
"data IS NOT NULL",
ImmutableList.of(
row(2L, toSparkVariantRow("bar", 30)), row(3L, toSparkVariantRow("baz", 35))));
});
}

private void checkOnlyIcebergFilters(
String predicate, String icebergFilters, List<Object[]> expectedRows) {

@@ -600,7 +669,7 @@ private void checkFilters(
if (sparkFilter != null) {
assertThat(planAsString)
.as("Post scan filter should match")
.contains("Filter (" + sparkFilter + ")");
.containsAnyOf("Filter (" + sparkFilter + ")", "Filter " + sparkFilter);
} else {
assertThat(planAsString).as("Should be no post scan filter").doesNotContain("Filter (");
}
@@ -613,4 +682,22 @@ private void checkFilters(
private Timestamp timestamp(String timestampAsString) {
return Timestamp.from(Instant.parse(timestampAsString));
}

private VariantVal toSparkVariantRow(String field, int num) {
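// build a two-field variant object {"field": ..., "num": ...}, serialize its metadata and
// value into little-endian buffers, and wrap them as a Spark VariantVal (value bytes, then
// metadata bytes)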
VariantMetadata metadata = Variants.metadata("field", "num");

ShreddedObject obj = Variants.object(metadata);
obj.put("field", Variants.of(field));
obj.put("num", Variants.of(num));

ByteBuffer metadataBuffer =
ByteBuffer.allocate(metadata.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
metadata.writeTo(metadataBuffer, 0);

ByteBuffer valueBuffer = ByteBuffer.allocate(obj.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN);
obj.writeTo(valueBuffer, 0);

return new VariantVal(
ByteBuffers.toByteArray(valueBuffer), ByteBuffers.toByteArray(metadataBuffer));
}
}