diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java index b6b4ffffafac..d52487a4a45c 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java @@ -154,6 +154,27 @@ public static boolean isOptionSchema(Schema schema) { return false; } + /** + * This method decides whether a schema is of type union and is complex union and is optional + * + * Complex union: the number of options in union larger than 2 + * Optional: null is present in union + * + * @param schema input schema + * @return true if schema is complex union and it is optional + */ + public static boolean isOptionalComplexUnion(Schema schema) { + if (schema.getType() == UNION && schema.getTypes().size() > 2) { + for (Schema type : schema.getTypes()) { + if (type.getType() == Schema.Type.NULL) { + return true; + } + } + } + + return false; + } + static Schema toOption(Schema schema) { if (schema.getType() == UNION) { Preconditions.checkArgument(isOptionSchema(schema), diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java index 281d45b51c2a..83a2e9ecf1ba 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java @@ -52,8 +52,21 @@ public static T visit(Schema schema, AvroSchemaVisitor visitor) { case UNION: List types = schema.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); - for (Schema type : types) { - options.add(visit(type, visitor)); + if (AvroSchemaUtil.isOptionSchema(schema)) { + for (Schema type : types) { + options.add(visit(type, visitor)); + } + } else { + // complex union case + int nonNullIdx = 0; + for (Schema type : types) { + if (type.getType() != Schema.Type.NULL) { + options.add(visitWithName("field" + nonNullIdx, type, visitor)); + nonNullIdx += 1; + } else { + options.add(visit(type, visitor)); + } + } } return visitor.union(schema, options); diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java index e6f1c6eb5097..eb78ac6dcff3 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java @@ -79,11 +79,30 @@ private static T visitRecord(Types.StructType struct, Schema record, AvroSch private static T visitUnion(Type type, Schema union, AvroSchemaWithTypeVisitor visitor) { List types = union.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); - for (Schema branch : types) { - if (branch.getType() == Schema.Type.NULL) { - options.add(visit((Type) null, branch, visitor)); - } else { - options.add(visit(type, branch, visitor)); + + // simple union case + if (AvroSchemaUtil.isOptionSchema(union)) { + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); + } else { + options.add(visit(type, branch, visitor)); + } + } + } else { // complex union case + Preconditions.checkArgument(type instanceof Types.StructType, + "Cannot visit invalid Iceberg type: %s for Avro complex union type: %s", type, union); + + List fields = type.asStructType().fields(); + // start index from 1 because 0 is the tag field which doesn't exist in the original Avro schema + int index = 1; + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); + } else { + options.add(visit(fields.get(index).type(), branch, visitor)); + index += 1; + } } } return visitor.union(type, union, options); diff --git a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java index 815e6bc5db85..256be6ebc123 100644 --- a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java +++ b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java @@ -19,6 +19,7 @@ package org.apache.iceberg.avro; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -154,16 +155,48 @@ public Schema.Field field(Schema.Field field, Supplier fieldResult) { @Override public Schema union(Schema union, Iterable options) { - Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", union); - Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); - Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); + if (AvroSchemaUtil.isOptionSchema(union)) { + Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); + Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); - if (!Objects.equals(nonNullOriginal, nonNullResult)) { - return AvroSchemaUtil.toOption(nonNullResult); - } + if (!Objects.equals(nonNullOriginal, nonNullResult)) { + return AvroSchemaUtil.toOption(nonNullResult); + } + + return union; + } else { // Complex union + Preconditions.checkArgument(current instanceof Types.StructType, + "Incompatible projected type: %s for Avro complex union type: %s", current, union); + + Types.StructType asStructType = current.asStructType(); + + long nonNullBranchesCount = union.getTypes().stream() + .filter(branch -> branch.getType() != Schema.Type.NULL).count(); + Preconditions.checkState(asStructType.fields().size() > nonNullBranchesCount, + "Column projection on struct converted from Avro complex union type: %s is not supported", union); + + Iterator resultBranchIterator = options.iterator(); + + // we start index from 1 because 0 is the tag field which doesn't exist in the original Avro + int index = 1; + List resultBranches = Lists.newArrayListWithExpectedSize(union.getTypes().size()); - return union; + try { + for (Schema originalBranch : union.getTypes()) { + if (originalBranch.getType() == Schema.Type.NULL) { + resultBranches.add(resultBranchIterator.next()); + } else { + this.current = asStructType.fields().get(index).type(); + resultBranches.add(resultBranchIterator.next()); + index += 1; + } + } + + return Schema.createUnion(resultBranches); + } finally { + this.current = asStructType; + } + } } @Override diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java index 99855ea050da..c4e0b0f35c7f 100644 --- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java +++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java @@ -119,25 +119,27 @@ public Schema record(Schema record, List names, List fields) { @Override public Schema union(Schema union, List options) { - Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", union); - - // only unions with null are allowed, and a null schema results in null - Schema pruned = null; - if (options.get(0) != null) { - pruned = options.get(0); - } else if (options.get(1) != null) { - pruned = options.get(1); - } + if (AvroSchemaUtil.isOptionSchema(union)) { + // case option union + Schema pruned = null; + if (options.get(0) != null) { + pruned = options.get(0); + } else if (options.get(1) != null) { + pruned = options.get(1); + } - if (pruned != null) { - if (!Objects.equals(pruned, AvroSchemaUtil.fromOption(union))) { - return AvroSchemaUtil.toOption(pruned); + if (pruned != null) { + if (!Objects.equals(pruned, AvroSchemaUtil.fromOption(union))) { + return AvroSchemaUtil.toOption(pruned); + } + return union; } - return union; - } - return null; + return null; + } else { + // Complex union case + return copyUnion(union, options); + } } @Override @@ -323,4 +325,19 @@ private static Schema.Field copyField(Schema.Field field, Schema newSchema, Inte private static boolean isOptionSchemaWithNonNullFirstOption(Schema schema) { return AvroSchemaUtil.isOptionSchema(schema) && schema.getTypes().get(0).getType() != Schema.Type.NULL; } + + // for primitive types, the visitResult will be null, we want to reuse the primitive types from the original + // schema, while for nested types, we want to use the visitResult because they have content from the previous + // recursive calls. + private static Schema copyUnion(Schema record, List visitResults) { + List branches = Lists.newArrayListWithExpectedSize(visitResults.size()); + for (int i = 0; i < visitResults.size(); i++) { + if (visitResults.get(i) == null) { + branches.add(record.getTypes().get(i)); + } else { + branches.add(visitResults.get(i)); + } + } + return Schema.createUnion(branches); + } } diff --git a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java index 73c86226007f..770e6f730f5f 100644 --- a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java +++ b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java @@ -93,7 +93,7 @@ public Type record(Schema record, List names, List fieldTypes) { Type fieldType = fieldTypes.get(i); int fieldId = getId(field); - if (AvroSchemaUtil.isOptionSchema(field.schema())) { + if (AvroSchemaUtil.isOptionSchema(field.schema()) || AvroSchemaUtil.isOptionalComplexUnion(field.schema())) { newFields.add(Types.NestedField.optional(fieldId, field.name(), fieldType, field.doc())); } else { newFields.add(Types.NestedField.required(fieldId, field.name(), fieldType, field.doc())); @@ -105,13 +105,27 @@ public Type record(Schema record, List names, List fieldTypes) { @Override public Type union(Schema union, List options) { - Preconditions.checkArgument(AvroSchemaUtil.isOptionSchema(union), - "Unsupported type: non-option union: %s", union); - // records, arrays, and maps will check nullability later - if (options.get(0) == null) { - return options.get(1); + if (AvroSchemaUtil.isOptionSchema(union)) { + // Optional simple union + // records, arrays, and maps will check nullability later + if (options.get(0) == null) { + return options.get(1); + } else { + return options.get(0); + } } else { - return options.get(0); + // Complex union + List newFields = Lists.newArrayList(); + newFields.add(Types.NestedField.required(allocateId(), "tag", Types.IntegerType.get())); + + int tagIndex = 0; + for (Type type : options) { + if (type != null) { + newFields.add(Types.NestedField.optional(allocateId(), "field" + tagIndex++, type)); + } + } + + return Types.StructType.of(newFields); } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java b/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java index 79a54f886c1d..cb212fc6b5ef 100644 --- a/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java +++ b/core/src/test/java/org/apache/iceberg/avro/TestBuildAvroProjection.java @@ -256,4 +256,121 @@ public void projectMapWithLessFieldInValueSchema() { assertEquals("Unexpected value ID discovered on the projected map schema", 1, Integer.valueOf(actual.getProp(AvroSchemaUtil.VALUE_ID_PROP)).intValue()); } + + @Test + public void projectUnionWithBranchSchemaUnchanged() { + + final Type icebergType = Types.StructType.of( + Types.NestedField.required(0, "tag", Types.IntegerType.get()), + Types.NestedField.optional(1, "field0", Types.IntegerType.get()), + Types.NestedField.optional(2, "field1", Types.StringType.get()) + ); + + final org.apache.avro.Schema expected = SchemaBuilder.unionOf() + .intType() + .and() + .stringType() + .endUnion(); + + final BuildAvroProjection testSubject = new BuildAvroProjection(icebergType, Collections.emptyMap()); + + final Iterable branches = expected.getTypes(); + + final org.apache.avro.Schema actual = testSubject.union(expected, branches); + + assertEquals("Union projection produced undesired union schema", + expected, actual); + } + + @Test + public void projectUnionWithTypePromotion() { + + final Type icebergType = Types.StructType.of( + Types.NestedField.required(0, "tag", Types.IntegerType.get()), + Types.NestedField.optional(1, "field0", Types.LongType.get()), + Types.NestedField.optional(2, "field1", Types.StringType.get()) + ); + + final org.apache.avro.Schema originalSchema = SchemaBuilder.unionOf() + .intType() + .and() + .stringType() + .endUnion(); + + // once projected onto iceberg schema, first branch of Avro union schema will be promoted from int to long + final org.apache.avro.Schema expected = SchemaBuilder.unionOf() + .longType() + .and() + .stringType() + .endUnion(); + + final BuildAvroProjection testSubject = new BuildAvroProjection(icebergType, Collections.emptyMap()); + + final Iterable branches = expected.getTypes(); + + final org.apache.avro.Schema actual = testSubject.union(originalSchema, branches); + + assertEquals("Union projection produced undesired union schema", + expected, actual); + } + + @Test + public void projectUnionWithExtraFieldInNestedType() { + + final Type icebergType = Types.StructType.of( + Types.NestedField.required(0, "tag", Types.IntegerType.get()), + Types.NestedField.optional(1, "field0", Types.StringType.get()), + Types.NestedField.optional(2, "field1", Types.StructType.of( + Types.NestedField.optional(3, "c1", Types.IntegerType.get()), + Types.NestedField.optional(4, "c2", Types.StringType.get()), + Types.NestedField.optional(5, "c3", Types.StringType.get()) + )) + ); + + final org.apache.avro.Schema originalSchema = SchemaBuilder.unionOf() + .stringType() + .and() + .record("r") + .fields() + .name("c1") + .type() + .intType() + .noDefault() + .name("c2") + .type() + .stringType() + .noDefault() + .endRecord() + .endUnion(); + + // once projected onto iceberg schema, the avro schema will have an extra string column in struct within union + final org.apache.avro.Schema expected = SchemaBuilder.unionOf() + .stringType() + .and() + .record("r") + .fields() + .name("c1") + .type() + .intType() + .noDefault() + .name("c2") + .type() + .stringType() + .noDefault() + .name("c3") + .type() + .stringType() + .noDefault() + .endRecord() + .endUnion(); + + final BuildAvroProjection testSubject = new BuildAvroProjection(icebergType, Collections.emptyMap()); + + final Iterable branches = expected.getTypes(); + + final org.apache.avro.Schema actual = testSubject.union(originalSchema, branches); + + assertEquals("Union projection produced undesired union schema", + expected, actual); + } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java new file mode 100644 index 000000000000..df778211b470 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.avro; + +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.iceberg.types.Types; +import org.junit.Assert; +import org.junit.Test; + +public class TestUnionSchemaConversions { + + @Test + public void testRequiredComplexUnion() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = new org.apache.iceberg.Schema( + Types.NestedField.required(0, "unionCol", Types.StructType.of( + Types.NestedField.required(1, "tag", Types.IntegerType.get()), + Types.NestedField.optional(2, "field0", Types.IntegerType.get()), + Types.NestedField.optional(3, "field1", Types.StringType.get()) + )) + ); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } + + @Test + public void testOptionalComplexUnion() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = new org.apache.iceberg.Schema( + Types.NestedField.optional(0, "unionCol", Types.StructType.of( + Types.NestedField.required(1, "tag", Types.IntegerType.get()), + Types.NestedField.optional(2, "field0", Types.IntegerType.get()), + Types.NestedField.optional(3, "field1", Types.StringType.get()) + )) + ); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } + + @Test + public void testSimpleUnionSchema() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("optionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .endUnion() + .nullDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + org.apache.iceberg.Schema expectedIcebergSchema = new org.apache.iceberg.Schema( + Types.NestedField.optional(0, "optionCol", Types.IntegerType.get()) + ); + + Assert.assertEquals(expectedIcebergSchema.toString(), icebergSchema.toString()); + } +} + diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java index c693e2e2c057..cb105195efd3 100644 --- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java +++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -28,6 +28,7 @@ import org.apache.avro.Schema; import org.apache.avro.io.DatumReader; import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; import org.apache.iceberg.avro.SupportsRowPosition; import org.apache.iceberg.avro.ValueReader; @@ -88,7 +89,11 @@ public ValueReader record(Types.StructType expected, Schema record, List union(Type expected, Schema union, List> options) { - return ValueReaders.union(options); + if (AvroSchemaUtil.isOptionSchema(union)) { + return ValueReaders.union(options); + } else { + return SparkValueReaders.complexUnion(union, options); + } } @Override diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java index 0d3ce2b28d0b..664f3e35dd9d 100644 --- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java +++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -27,6 +27,8 @@ import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import java.util.Objects; +import org.apache.avro.Schema; import org.apache.avro.io.Decoder; import org.apache.avro.util.Utf8; import org.apache.iceberg.avro.ValueReader; @@ -81,6 +83,10 @@ static ValueReader struct(List> readers, Types.Struc return new StructReader(readers, struct, idToConstant); } + static ValueReader complexUnion(Schema schema, List> readers) { + return new ComplexUnionReader(schema, readers); + } + private static class StringReader implements ValueReader { private static final StringReader INSTANCE = new StringReader(); @@ -285,4 +291,52 @@ protected void set(InternalRow struct, int pos, Object value) { } } } + + private static class ComplexUnionReader implements ValueReader { + private final List branches; + private final ValueReader[] readers; + private int nullIndex; + + private ComplexUnionReader(Schema schema, List> readers) { + this.branches = schema.getTypes(); + this.readers = new ValueReader[readers.size()]; + for (int i = 0; i < this.readers.length; i += 1) { + this.readers[i] = readers.get(i); + } + + // Calculate NULL branch if it exists in the union schema + this.nullIndex = Integer.MAX_VALUE; + for (int i = 0; i < branches.size(); i++) { + Schema branch = branches.get(i); + if (Objects.equals(branch.getType(), Schema.Type.NULL)) { + this.nullIndex = i; + break; + } + } + } + + @Override + public InternalRow read(Decoder decoder, Object reuse) throws IOException { + int index = decoder.readIndex(); + if (index == nullIndex) { + // if it is a null data, directly return null as the whole union result + return null; + } + + // otherwise, we need to return an InternalRow as a struct data + InternalRow struct = new GenericInternalRow(nullIndex < Integer.MAX_VALUE ? + branches.size() : branches.size() + 1); + for (int i = 0; i < struct.numFields(); i += 1) { + struct.setNullAt(i); + } + + Object value = readers[index].read(decoder, reuse); + + int outputFieldIndex = nullIndex < index ? index - 1 : index; + struct.setInt(0, outputFieldIndex); + struct.update(outputFieldIndex + 1, value); // add 1 to offset `tag` field + + return struct; + } + } } diff --git a/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java new file mode 100644 index 000000000000..2cf91a175be0 --- /dev/null +++ b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestSparkAvroUnions { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void writeAndValidateRequiredComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals(3, rows.get(0).getStruct(0, 3).numFields()); + Assert.assertEquals(1, rows.get(0).getStruct(0, 3).getInt(0)); + Assert.assertTrue(rows.get(0).getStruct(0, 3).isNullAt(1)); + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); + + Assert.assertEquals(3, rows.get(1).getStruct(0, 3).numFields()); + Assert.assertEquals(0, rows.get(1).getStruct(0, 3).getInt(0)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(1).getStruct(0, 3).isNullAt(2)); + } + } + + @Test + public void writeAndValidateOptionalComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .nullDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + GenericData.Record unionRecord3 = new GenericData.Record(avroSchema); + unionRecord3.put("unionCol", null); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + writer.append(unionRecord3); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(2).isNullAt(0)); + } + } + + @Test + public void writeAndValidateSingleTypeUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .endUnion() + .nullDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", 0); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + Assert.assertEquals(0, rows.get(0).getInt(0)); + Assert.assertEquals(1, rows.get(1).getInt(0)); + } + } + + @Test + public void testDeeplyNestedUnionSchema1() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("col1") + .type() + .array() + .items() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("col1", Arrays.asList("foo", 1)); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("col1", Arrays.asList(2, "bar")); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + // making sure it reads the correctly nested structured data, based on the transformation from union to struct + Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 3).getString(2)); + } + } + + @Test + public void testDeeplyNestedUnionSchema2() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("col1") + .type() + .array() + .items() + .unionOf() + .record("r1") + .fields() + .name("id") + .type() + .intType() + .noDefault() + .endRecord() + .and() + .record("r2") + .fields() + .name("id") + .type() + .intType() + .noDefault() + .endRecord() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record outer = new GenericData.Record(avroSchema); + GenericData.Record inner = new GenericData.Record(avroSchema.getFields().get(0).schema() + .getElementType().getTypes().get(0)); + + inner.put("id", 1); + outer.put("col1", Arrays.asList(inner)); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(outer); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + + // making sure it reads the correctly nested structured data, based on the transformation from union to struct + Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 3).getStruct(1, 1).getInt(0)); + } + } + + @Test + public void writeAndValidateColumnProjectionInComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "foo"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema).select("unionCol.field0"); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + } catch (IllegalStateException e) { + Assert.assertTrue(e.getMessage().contains("Column projection")); + } + } +}