diff --git a/core/src/main/java/org/apache/iceberg/avro/Avro.java b/core/src/main/java/org/apache/iceberg/avro/Avro.java index 85cc8d902026..ec93a2004a4b 100644 --- a/core/src/main/java/org/apache/iceberg/avro/Avro.java +++ b/core/src/main/java/org/apache/iceberg/avro/Avro.java @@ -611,6 +611,10 @@ public static class ReadBuilder { private Function> createReaderFunc = null; private BiFunction> createReaderBiFunc = null; + // This field is temporarily added to pass the map between Avro schema name and Iceberg field id + // mocked in the test to the classes of ProjectionDatumReader and PruneColumns + private Map avroSchemaNameToIcebergFieldId = null; + @SuppressWarnings("UnnecessaryLambda") private final Function> defaultCreateReaderFunc = readSchema -> { @@ -683,6 +687,12 @@ public ReadBuilder classLoader(ClassLoader classLoader) { return this; } + public ReadBuilder withAvroSchemaNameToIcebergFieldId( + Map schemaNameToIcebergFieldId) { + this.avroSchemaNameToIcebergFieldId = schemaNameToIcebergFieldId; + return this; + } + public AvroIterable build() { Preconditions.checkNotNull(schema, "Schema is required"); Function> readerFunc; @@ -696,7 +706,8 @@ public AvroIterable build() { return new AvroIterable<>( file, - new ProjectionDatumReader<>(readerFunc, schema, renames, nameMapping), + new ProjectionDatumReader<>( + readerFunc, schema, renames, nameMapping, avroSchemaNameToIcebergFieldId), start, length, reuseContainers); diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java index 46c17722f8f7..14c2f655b5aa 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java @@ -27,6 +27,7 @@ import org.apache.avro.Schema; import org.apache.iceberg.mapping.MappedField; import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; import 
org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; @@ -47,12 +48,16 @@ private AvroSchemaUtil() {} public static final String ELEMENT_ID_PROP = "element-id"; public static final String ADJUST_TO_UTC_PROP = "adjust-to-utc"; + public static final String BRANCH_ID_PROP = "branch-id"; + private static final Schema NULL = Schema.create(Schema.Type.NULL); private static final Schema.Type MAP = Schema.Type.MAP; private static final Schema.Type ARRAY = Schema.Type.ARRAY; private static final Schema.Type UNION = Schema.Type.UNION; private static final Schema.Type RECORD = Schema.Type.RECORD; + private static final Joiner DOT = Joiner.on('.'); + public static Schema convert(org.apache.iceberg.Schema schema, String tableName) { return convert(schema, ImmutableMap.of(schema.asStruct(), tableName)); } @@ -121,6 +126,15 @@ public static Schema pruneColumns( return new PruneColumns(selectedIds, nameMapping).rootSchema(schema); } + public static Schema pruneColumns( + Schema schema, + Set selectedIds, + NameMapping nameMapping, + Map avroSchemaFieldNameToIcebergFieldId) { + return new PruneColumns(selectedIds, nameMapping, avroSchemaFieldNameToIcebergFieldId) + .rootSchema(schema); + } + public static Schema buildAvroProjection( Schema schema, org.apache.iceberg.Schema expected, Map renames) { return AvroCustomOrderSchemaVisitor.visit(schema, new BuildAvroProjection(expected, renames)); @@ -158,6 +172,27 @@ public static boolean isOptionSchema(Schema schema) { return false; } + /** + * This method decides whether a schema is of type union and is complex union and is optional + * + *

Complex union: the number of options in union larger than 2 Optional: null is present in + * union + * + * @param schema input schema + * @return true if schema is complex union and it is optional + */ + public static boolean isOptionalComplexUnion(Schema schema) { + if (schema.getType() == UNION && schema.getTypes().size() > 2) { + for (Schema type : schema.getTypes()) { + if (type.getType() == Schema.Type.NULL) { + return true; + } + } + } + + return false; + } + static Schema toOption(Schema schema) { if (schema.getType() == UNION) { Preconditions.checkArgument( @@ -168,6 +203,18 @@ static Schema toOption(Schema schema) { } } + public static Schema toOption(Schema schema, boolean nullIsSecondElement) { + if (schema.getType() == UNION) { + Preconditions.checkArgument( + isOptionSchema(schema), "Union schemas are not supported: %s", schema); + return schema; + } else if (nullIsSecondElement) { + return Schema.createUnion(schema, NULL); + } else { + return Schema.createUnion(NULL, schema); + } + } + static Schema fromOption(Schema schema) { Preconditions.checkArgument( schema.getType() == UNION, "Expected union schema but was passed: %s", schema); @@ -477,4 +524,24 @@ private static String sanitize(char character) { } return "_x" + Integer.toHexString(character).toUpperCase(); } + + /** + * Resolves the Iceberg field id of a union branch. The branch-id property of the branch schema + * takes precedence; otherwise the id is looked up in nameToIdMap under the dot-joined + * parentFieldNames followed by the branch name. + * + * @return the branch's field id, or null if neither source provides one + */ + public static Integer getBranchId( + Schema branch, Map nameToIdMap, Iterable parentFieldNames) { + Object id = branch.getObjectProp(BRANCH_ID_PROP); + if (id != null) { + return toInt(id); + } else if (nameToIdMap != null && !nameToIdMap.isEmpty()) { + List names = Lists.newArrayList(parentFieldNames); + names.add(branch.getName()); + return nameToIdMap.get(DOT.join(names)); + } + return null; + } } diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java index f22a3592ad3d..96688c37eea8 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java +++ 
b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java @@ -51,11 +51,23 @@ public static T visit(Schema schema, AvroSchemaVisitor visitor) { case UNION: List types = schema.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); - for (Schema type : types) { - options.add(visit(type, visitor)); + if (AvroSchemaUtil.isOptionSchema(schema)) { + for (Schema type : types) { + options.add(visit(type, visitor)); + } + } else { + // complex union case + int nonNullIdx = 0; + for (Schema type : types) { + if (type.getType() != Schema.Type.NULL) { + options.add(visitWithName("field" + nonNullIdx, type, visitor)); + nonNullIdx += 1; + } else { + options.add(visit(type, visitor)); + } + } } return visitor.union(schema, options); - case ARRAY: if (schema.getLogicalType() instanceof LogicalMap) { return visitor.array(schema, visit(schema.getElementType(), visitor)); diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java index 85a8718abfce..9802c3adff25 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java @@ -82,11 +82,38 @@ private static T visitRecord( private static T visitUnion(Type type, Schema union, AvroSchemaWithTypeVisitor visitor) { List types = union.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); - for (Schema branch : types) { - if (branch.getType() == Schema.Type.NULL) { - options.add(visit((Type) null, branch, visitor)); - } else { - options.add(visit(type, branch, visitor)); + + // simple union case + if (AvroSchemaUtil.isOptionSchema(union)) { + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); + } else { + options.add(visit(type, branch, visitor)); + } + } + } else { // complex union case + 
Preconditions.checkArgument( + type instanceof Types.StructType, + "Cannot visit invalid Iceberg type: %s for Avro complex union type: %s", + type, + union); + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); + } else { + Types.NestedField expectedSchemaField = null; + String branchId = branch.getProp(AvroSchemaUtil.BRANCH_ID_PROP); + if (branchId != null) { + expectedSchemaField = type.asStructType().field(Integer.parseInt(branchId)); + } + if (expectedSchemaField != null) { + options.add(visit(expectedSchemaField.type(), branch, visitor)); + } else { + Type pseudoExpectedSchemaField = AvroSchemaUtil.convert(branch); + options.add(visit(pseudoExpectedSchemaField, branch, visitor)); + } + } } } return visitor.union(type, union, options); diff --git a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java index 3f1a71a9e6c2..2322ab45a54d 100644 --- a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java +++ b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java @@ -159,17 +159,15 @@ public Schema.Field field(Schema.Field field, Supplier fieldResult) { @Override public Schema union(Schema union, Iterable options) { - Preconditions.checkState( - AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", - union); - Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); - Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); - - if (!Objects.equals(nonNullOriginal, nonNullResult)) { - return AvroSchemaUtil.toOption(nonNullResult); - } + if (AvroSchemaUtil.isOptionSchema(union)) { + Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); + Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); + if (!Objects.equals(nonNullOriginal, nonNullResult)) { + boolean nullIsSecondOption = 
union.getTypes().get(1).getType() == Schema.Type.NULL; + return AvroSchemaUtil.toOption(nonNullResult, nullIsSecondOption); + } + } return union; } diff --git a/core/src/main/java/org/apache/iceberg/avro/ProjectionDatumReader.java b/core/src/main/java/org/apache/iceberg/avro/ProjectionDatumReader.java index 3b04fe30db65..539fd8d77634 100644 --- a/core/src/main/java/org/apache/iceberg/avro/ProjectionDatumReader.java +++ b/core/src/main/java/org/apache/iceberg/avro/ProjectionDatumReader.java @@ -39,6 +39,8 @@ public class ProjectionDatumReader implements DatumReader, SupportsRowPosi private Schema fileSchema = null; private DatumReader wrapped = null; + private Map avroSchemaNameToIcebergFieldId = null; + public ProjectionDatumReader( Function> getReader, org.apache.iceberg.Schema expectedSchema, @@ -50,6 +52,16 @@ public ProjectionDatumReader( this.nameMapping = nameMapping; } + public ProjectionDatumReader( + Function> getReader, + org.apache.iceberg.Schema expectedSchema, + Map renames, + NameMapping nameMapping, + Map avroSchemaNameToIcebergFieldId) { + this(getReader, expectedSchema, renames, nameMapping); + this.avroSchemaNameToIcebergFieldId = avroSchemaNameToIcebergFieldId; + } + @Override public void setRowPositionSupplier(Supplier posSupplier) { if (wrapped instanceof SupportsRowPosition) { @@ -64,7 +76,9 @@ public void setSchema(Schema newFileSchema) { nameMapping = MappingUtil.create(expectedSchema); } Set projectedIds = TypeUtil.getProjectedIds(expectedSchema); - Schema prunedSchema = AvroSchemaUtil.pruneColumns(newFileSchema, projectedIds, nameMapping); + Schema prunedSchema = + AvroSchemaUtil.pruneColumns( + newFileSchema, projectedIds, nameMapping, avroSchemaNameToIcebergFieldId); this.readSchema = AvroSchemaUtil.buildAvroProjection(prunedSchema, expectedSchema, renames); this.wrapped = newDatumReader(); } diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java index 
2de2c0fe029d..8ceaef4d02ce 100644 --- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java +++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java @@ -39,12 +39,22 @@ class PruneColumns extends AvroSchemaVisitor { private final Set selectedIds; private final NameMapping nameMapping; + private Map avroSchemaFieldNameToIcebergFieldId; + PruneColumns(Set selectedIds, NameMapping nameMapping) { Preconditions.checkNotNull(selectedIds, "Selected field ids cannot be null"); this.selectedIds = selectedIds; this.nameMapping = nameMapping; } + PruneColumns( + Set selectedIds, + NameMapping nameMapping, + Map avroSchemaFieldNameToIcebergFieldId) { + this(selectedIds, nameMapping); + this.avroSchemaFieldNameToIcebergFieldId = avroSchemaFieldNameToIcebergFieldId; + } + Schema rootSchema(Schema record) { Schema result = visit(record, this); if (result != null) { @@ -118,27 +128,27 @@ public Schema record(Schema record, List names, List fields) { @Override public Schema union(Schema union, List options) { - Preconditions.checkState( - AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", - union); - - // only unions with null are allowed, and a null schema results in null - Schema pruned = null; - if (options.get(0) != null) { - pruned = options.get(0); - } else if (options.get(1) != null) { - pruned = options.get(1); - } + if (AvroSchemaUtil.isOptionSchema(union)) { + // case option union + Schema pruned = null; + if (options.get(0) != null) { + pruned = options.get(0); + } else if (options.get(1) != null) { + pruned = options.get(1); + } - if (pruned != null) { - if (!Objects.equals(pruned, AvroSchemaUtil.fromOption(union))) { - return AvroSchemaUtil.toOption(pruned); + if (pruned != null) { + if (!Objects.equals(pruned, AvroSchemaUtil.fromOption(union))) { + return AvroSchemaUtil.toOption(pruned); + } + return union; } - return union; - } - return null; + return null; + } else { + // Complex union case + return 
pruneComplexUnion(union, options); + } } @Override @@ -345,4 +355,32 @@ private static boolean isOptionSchemaWithNonNullFirstOption(Schema schema) { return AvroSchemaUtil.isOptionSchema(schema) && schema.getTypes().get(0).getType() != Schema.Type.NULL; } + + /** + * For primitive types, the visitResult will be null, we want to reuse the primitive types from + * the original schema, while for nested types, we want to use the visitResult because they have + * content from the previous recursive calls. Also the id of the field in the Iceberg schema + * corresponding to each branch schema of the union is assigned as the property named "branch-id" + * of the branch schema. + */ + private Schema pruneComplexUnion(Schema union, List visitResults) { + List branches = Lists.newArrayListWithExpectedSize(visitResults.size()); + + List unionTypes = union.getTypes(); + for (int i = 0; i < visitResults.size(); ++i) { + Schema branchSchema = visitResults.get(i); + if (branchSchema == null) { + branchSchema = unionTypes.get(i); + } + Integer branchId = + AvroSchemaUtil.getBranchId( + branchSchema, avroSchemaFieldNameToIcebergFieldId, fieldNames()); + if (branchId != null) { + branchSchema.addProp(AvroSchemaUtil.BRANCH_ID_PROP, String.valueOf(branchId)); + } + + branches.add(branchSchema); + } + return Schema.createUnion(branches); + } } diff --git a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java index 174d63975195..5e2d9cd9ddad 100644 --- a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java +++ b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java @@ -92,7 +92,8 @@ public Type record(Schema record, List names, List fieldTypes) { Type fieldType = fieldTypes.get(i); int fieldId = getId(field); - if (AvroSchemaUtil.isOptionSchema(field.schema())) { + if (AvroSchemaUtil.isOptionSchema(field.schema()) + || AvroSchemaUtil.isOptionalComplexUnion(field.schema())) { 
newFields.add(Types.NestedField.optional(fieldId, field.name(), fieldType, field.doc())); } else { newFields.add(Types.NestedField.required(fieldId, field.name(), fieldType, field.doc())); @@ -104,13 +105,27 @@ public Type record(Schema record, List names, List fieldTypes) { @Override public Type union(Schema union, List options) { - Preconditions.checkArgument( - AvroSchemaUtil.isOptionSchema(union), "Unsupported type: non-option union: %s", union); - // records, arrays, and maps will check nullability later - if (options.get(0) == null) { - return options.get(1); + if (AvroSchemaUtil.isOptionSchema(union)) { + // Optional simple union + // records, arrays, and maps will check nullability later + if (options.get(0) == null) { + return options.get(1); + } else { + return options.get(0); + } } else { - return options.get(0); + // Complex union + List newFields = Lists.newArrayList(); + newFields.add(Types.NestedField.required(allocateId(), "tag", Types.IntegerType.get())); + + int tagIndex = 0; + for (Type type : options) { + if (type != null) { + newFields.add(Types.NestedField.optional(allocateId(), "field" + tagIndex++, type)); + } + } + + return Types.StructType.of(newFields); } } diff --git a/core/src/main/java/org/apache/iceberg/avro/ValueReaders.java b/core/src/main/java/org/apache/iceberg/avro/ValueReaders.java index 19789cce82fc..cd10263f8223 100644 --- a/core/src/main/java/org/apache/iceberg/avro/ValueReaders.java +++ b/core/src/main/java/org/apache/iceberg/avro/ValueReaders.java @@ -25,6 +25,7 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.Arrays; import java.util.Collection; import java.util.Deque; import java.util.Iterator; @@ -43,6 +44,7 @@ import org.apache.iceberg.common.DynConstructors; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Type; import 
org.apache.iceberg.types.Types; import org.apache.iceberg.util.UUIDUtil; @@ -752,4 +754,75 @@ public Long read(Decoder ignored, Object reuse) throws IOException { return currentPosition; } } + + public static class ComplexUnionReader implements ValueReader { + private static final String UNION_TAG_FIELD_NAME = "tag"; + private final ValueReader[] readers; + private final int[] projectedFieldIdsToIdxInReturnedRow; + private boolean isTagFieldProjected; + private int numOfFieldsInReturnedRow; + private int nullTypeIndex; + + public ValueReader[] getReaders() { + return readers; + } + + public int[] getProjectedFieldIdsToIdxInReturnedRow() { + return projectedFieldIdsToIdxInReturnedRow; + } + + public boolean isTagFieldProjected() { + return isTagFieldProjected; + } + + public int getNumOfFieldsInReturnedRow() { + return numOfFieldsInReturnedRow; + } + + public int getNullTypeIndex() { + return nullTypeIndex; + } + + public ComplexUnionReader(List> readers, Type expected) { + this.readers = new ValueReader[readers.size()]; + for (int i = 0; i < this.readers.length; i += 1) { + this.readers[i] = readers.get(i); + } + + // checking if NULL type exists in Avro union schema + this.nullTypeIndex = -1; + for (int i = 0; i < this.readers.length; i++) { + if (this.readers[i] instanceof ValueReaders.NullReader) { + this.nullTypeIndex = i; + break; + } + } + + // Creating an integer array to track the mapping between the index of fields to be projected + // and the index of the value for the field stored in the returned row, + // if the value for a field equals to -1, it means the value of this field should not be + // stored in the returned row + int numberOfTypes = this.nullTypeIndex == -1 ? 
this.readers.length : this.readers.length - 1; + this.projectedFieldIdsToIdxInReturnedRow = new int[numberOfTypes]; + Arrays.fill(this.projectedFieldIdsToIdxInReturnedRow, -1); + this.numOfFieldsInReturnedRow = 0; + this.isTagFieldProjected = false; + for (Types.NestedField expectedStructField : expected.asStructType().fields()) { + String fieldName = expectedStructField.name(); + if (fieldName.equals(UNION_TAG_FIELD_NAME)) { + this.isTagFieldProjected = true; + this.numOfFieldsInReturnedRow++; + } else { + int projectedFieldIndex = Integer.valueOf(fieldName.substring(5)); + this.projectedFieldIdsToIdxInReturnedRow[projectedFieldIndex] = + this.numOfFieldsInReturnedRow++; + } + } + } + + @Override + public T read(Decoder decoder, Object reuse) throws IOException { + throw new UnsupportedOperationException(); + } + } } diff --git a/core/src/main/java/org/apache/iceberg/mapping/MappingUtil.java b/core/src/main/java/org/apache/iceberg/mapping/MappingUtil.java index de6ce2ad0425..3c3d4667280c 100644 --- a/core/src/main/java/org/apache/iceberg/mapping/MappingUtil.java +++ b/core/src/main/java/org/apache/iceberg/mapping/MappingUtil.java @@ -52,6 +52,9 @@ public static NameMapping create(Schema schema) { return new NameMapping(TypeUtil.visit(schema, CreateMapping.INSTANCE)); } + public static NameMapping createWithTypeNameForUnionBranch(Schema schema) { + return new NameMapping(TypeUtil.visit(schema, CreateMappingWithTypeNameForUnion.INSTANCE)); + } /** * Update a name-based mapping using changes to a schema. 
* @@ -307,4 +310,32 @@ public MappedFields primitive(Type.PrimitiveType primitive) { return null; // no mapping because primitives have no nested fields } } + + private static class CreateMappingWithTypeNameForUnion extends CreateMapping { + private static final CreateMappingWithTypeNameForUnion INSTANCE = + new CreateMappingWithTypeNameForUnion(); + + private CreateMappingWithTypeNameForUnion() {} + + @Override + public MappedFields struct(Types.StructType struct, List fieldResults) { + List fields = Lists.newArrayListWithExpectedSize(fieldResults.size()); + // a struct containing a "tag" field is treated as the struct form of a union: + // "tag" keeps its name, every other field is mapped by its type string + boolean isUnion = struct.fields().stream().anyMatch(a -> "tag".equals(a.name())); + + for (int i = 0; i < fieldResults.size(); i += 1) { + Types.NestedField field = struct.fields().get(i); + MappedFields result = fieldResults.get(i); + + fields.add( + MappedField.of( + field.fieldId(), + !isUnion || "tag".equals(field.name()) ? field.name() : field.type().toString(), + result)); + } + + return MappedFields.of(fields); + } + } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java b/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java index edee46685e32..106f06f5c455 100644 --- a/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java +++ b/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java @@ -410,4 +410,26 @@ public void projectMapWithLessFieldInValueSchema() { 1, Integer.valueOf(actual.getProp(AvroSchemaUtil.VALUE_ID_PROP)).intValue()); } + + @Test + public void projectUnionWithBranchSchemaUnchanged() { + + final Type icebergType = + Types.StructType.of( + Types.NestedField.required(0, "tag", Types.IntegerType.get()), + Types.NestedField.optional(1, "field0", Types.IntegerType.get()), + Types.NestedField.optional(2, "field1", Types.StringType.get())); + + final org.apache.avro.Schema expected = + SchemaBuilder.unionOf().intType().and().stringType().endUnion(); + + final BuildAvroProjection testSubject = + new 
BuildAvroProjection(icebergType, Collections.emptyMap()); + + final Iterable branches = expected.getTypes(); + + final org.apache.avro.Schema actual = testSubject.union(expected, branches); + + assertEquals("Union projection produced undesired union schema", expected, actual); + } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java new file mode 100644 index 000000000000..70e99fb68739 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.avro; + +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Test; + +public class TestUnionSchemaConversions { + + @Test + public void testRequiredComplexUnion() { + Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = + new org.apache.iceberg.Schema( + Types.NestedField.required( + 0, + "unionCol", + Types.StructType.of( + Types.NestedField.required(1, "tag", Types.IntegerType.get()), + Types.NestedField.optional(2, "field0", Types.IntegerType.get()), + Types.NestedField.optional(3, "field1", Types.StringType.get())))); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } + + @Test + public void testOptionalComplexUnion() { + Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = + new org.apache.iceberg.Schema( + Types.NestedField.optional( + 0, + "unionCol", + Types.StructType.of( + Types.NestedField.required(1, "tag", Types.IntegerType.get()), + Types.NestedField.optional(2, "field0", Types.IntegerType.get()), + Types.NestedField.optional(3, "field1", Types.StringType.get())))); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } + + @Test + public void testSimpleUnionSchema() { + Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("optionCol") + .type() + .unionOf() + .nullType() + 
.and() + .intType() + .endUnion() + .nullDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = + new org.apache.iceberg.Schema( + Types.NestedField.optional(0, "optionCol", Types.IntegerType.get())); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } +} diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java index 4622d2928ac4..56de92275316 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -27,6 +27,7 @@ import org.apache.avro.Schema; import org.apache.avro.io.DatumReader; import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; import org.apache.iceberg.avro.SupportsRowPosition; import org.apache.iceberg.avro.ValueReader; @@ -88,7 +89,11 @@ public ValueReader record( @Override public ValueReader union(Type expected, Schema union, List> options) { - return ValueReaders.union(options); + if (AvroSchemaUtil.isOptionSchema(union)) { + return ValueReaders.union(options); + } else { + return SparkValueReaders.complexUnion(options, expected); + } } @Override diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java index 11655c72d857..9a923a83149d 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -31,6 +31,7 @@ import org.apache.iceberg.avro.ValueReader; import org.apache.iceberg.avro.ValueReaders; import 
org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.UUIDUtil; import org.apache.spark.sql.catalyst.InternalRow; @@ -79,6 +80,16 @@ static ValueReader struct( return new StructReader(readers, struct, idToConstant); } + /** + * In case of complex union, table schema (i.e. Iceberg schema) can possibly be pruned because of + * the column projection, while file schema (i.e. Avro schema) can not be pruned to make the data + * read from the file successfully. Therefore, table schema needs to be passed to + * ComplexUnionReader to help it only return the data of the projected type in the union. + */ + static ValueReader complexUnion(List> readers, Type expected) { + return new ComplexUnionReader(readers, expected); + } + private static class StringReader implements ValueReader { private static final StringReader INSTANCE = new StringReader(); @@ -285,4 +296,46 @@ protected void set(InternalRow struct, int pos, Object value) { } } } + + static class ComplexUnionReader extends ValueReaders.ComplexUnionReader { + protected ComplexUnionReader(List> readers, Type expected) { + super(readers, expected); + } + + @Override + public InternalRow read(Decoder decoder, Object reuse) throws IOException { + InternalRow row = reuseOrCreate(reuse); + + int index = decoder.readIndex(); + if (index != super.getNullTypeIndex()) { + int fieldIndex = + (super.getNullTypeIndex() < 0 || index < super.getNullTypeIndex()) ? 
index : index - 1; + if (super.isTagFieldProjected()) { + row.setInt(0, fieldIndex); + } + + Object value = super.getReaders()[index].read(decoder, reuse); + if (super.getProjectedFieldIdsToIdxInReturnedRow()[fieldIndex] != -1) { + row.update(super.getProjectedFieldIdsToIdxInReturnedRow()[fieldIndex], value); + } + } else { + super.getReaders()[index].read(decoder, reuse); + } + + return row; + } + + private InternalRow reuseOrCreate(Object reuse) { + InternalRow row; + if (reuse instanceof InternalRow) { + row = (InternalRow) reuse; + } else { + row = new GenericInternalRow(super.getNumOfFieldsInReturnedRow()); + } + for (int i = 0; i < row.numFields(); ++i) { + row.setNullAt(i); + } + return row; + } + } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java new file mode 100644 index 000000000000..37526a9557a7 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkAvroUnions { + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void writeAndValidateRequiredComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + Map avroSchemaNameToIcebergFieldId = + new HashMap() { + { + put("unionCol", 0); + put("unionCol.tag", 1); + put("unionCol.int", 2); + put("unionCol.string", 3); + } + }; + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + 
.createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .withAvroSchemaNameToIcebergFieldId(avroSchemaNameToIcebergFieldId) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals(3, rows.get(0).getStruct(0, 3).numFields()); + Assert.assertEquals(1, rows.get(0).getStruct(0, 3).getInt(0)); + Assert.assertTrue(rows.get(0).getStruct(0, 3).isNullAt(1)); + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); + + Assert.assertEquals(3, rows.get(1).getStruct(0, 3).numFields()); + Assert.assertEquals(0, rows.get(1).getStruct(0, 3).getInt(0)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(1).getStruct(0, 3).isNullAt(2)); + } + } + + @Test + public void writeAndValidateOptionalComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .nullDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + GenericData.Record unionRecord3 = new GenericData.Record(avroSchema); + unionRecord3.put("unionCol", null); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + writer.append(unionRecord3); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + Map avroSchemaNameToIcebergFieldId = + new HashMap() { + { + put("unionCol.tag", 1); + put("unionCol.int", 2); + put("unionCol.string", 3); + put("unionCol", 0); + } + }; + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + 
.project(expectedSchema) + .withAvroSchemaNameToIcebergFieldId(avroSchemaNameToIcebergFieldId) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(2).getStruct(0, 3).isNullAt(0)); + Assert.assertTrue(rows.get(2).getStruct(0, 3).isNullAt(1)); + Assert.assertTrue(rows.get(2).getStruct(0, 3).isNullAt(2)); + } + } + + @Test + public void writeAndValidateSingleTypeUnion() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .endUnion() + .nullDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", 0); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals(0, rows.get(0).getInt(0)); + Assert.assertEquals(1, rows.get(1).getInt(0)); + } + } + + @Test + public void testDeeplyNestedUnionSchema1() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("col1") + .type() + .array() + .items() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + 
unionRecord1.put("col1", Arrays.asList("foo", 1)); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("col1", Arrays.asList(2, "bar")); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + Map avroSchemaNameToIcebergFieldId = + new HashMap() { + { + put("col1", 0); + put("col1.element", 4); + put("col1.element.string", 3); + put("col1.element.int", 2); + put("col1.element.tag", 1); + } + }; + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .withAvroSchemaNameToIcebergFieldId(avroSchemaNameToIcebergFieldId) + .build()) { + rows = Lists.newArrayList(reader); + + // making sure it reads the correctly nested structured data, based on the transformation from + // union to struct + Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 3).getString(2)); + } + } + + @Test + public void testDeeplyNestedUnionSchema2() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("col1") + .type() + .array() + .items() + .unionOf() + .record("r1") + .fields() + .name("id") + .type() + .intType() + .noDefault() + .endRecord() + .and() + .record("r2") + .fields() + .name("id") + .type() + .intType() + .noDefault() + .endRecord() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record outer = new GenericData.Record(avroSchema); + GenericData.Record inner = + new GenericData.Record( + avroSchema.getFields().get(0).schema().getElementType().getTypes().get(0)); + + inner.put("id", 1); + outer.put("col1", Arrays.asList(inner)); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new 
GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(outer); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + Map avroSchemaNameToIcebergFieldId = + new HashMap() { + { + put("col1", 0); + put("col1.element.field0.id", 1); + put("col1.element.field1.id", 2); + put("col1.element.tag", 3); + put("col1.element.field0", 4); + put("col1.element.field1", 5); + put("col1.element", 6); + } + }; + + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .withAvroSchemaNameToIcebergFieldId(avroSchemaNameToIcebergFieldId) + .build()) { + rows = Lists.newArrayList(reader); + + // making sure it reads the correctly nested structured data, based on the transformation from + // union to struct + Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 3).getStruct(1, 1).getInt(0)); + } + } + + @Test + public void writeAndValidateRequiredComplexUnionWithProjection() throws IOException { + org.apache.avro.Schema avroSchema = + SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = + new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema).select("unionCol.field0"); + Map avroSchemaNameToIcebergFieldId = + new HashMap() { + { + put("unionCol", 0); + put("unionCol.field0", 2); + } + }; + List rows; + try (AvroIterable reader = + Avro.read(Files.localInput(testFile)) + 
.createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .withAvroSchemaNameToIcebergFieldId(avroSchemaNameToIcebergFieldId) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals(1, rows.get(0).getStruct(0, 1).numFields()); + Assert.assertTrue(rows.get(0).getStruct(0, 1).isNullAt(0)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 1).getInt(0)); + } + } +}