diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java index 1b1be01f41..b3ef5eb156 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java @@ -135,6 +135,27 @@ public static boolean isOptionSchema(Schema schema) { return false; } + /** + * Determines whether a schema is a union that is both complex and optional. + * + * Complex union: the number of options in the union is not equal to 2. + * Optional: null is one of the union's options. + * + * @param schema input schema + * @return true if the schema is a complex union and is optional + */ + public static boolean isOptionalComplexUnion(Schema schema) { + if (schema.getType() == UNION && schema.getTypes().size() != 2) { + for (Schema type : schema.getTypes()) { + if (type.getType() == Schema.Type.NULL) { + return true; + } + } + } + + return false; + } + public static Schema toOption(Schema schema) { if (schema.getType() == UNION) { Preconditions.checkArgument(isOptionSchema(schema), diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java index e6f1c6eb50..4c8aff64be 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java @@ -79,11 +79,18 @@ private static T visitRecord(Types.StructType struct, Schema record, AvroSch private static T visitUnion(Type type, Schema union, AvroSchemaWithTypeVisitor visitor) { List types = union.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); + + int index = 0; for (Schema branch : types) { if (branch.getType() == Schema.Type.NULL) { options.add(visit((Type) null, branch, visitor)); } else { - options.add(visit(type, branch, visitor)); + if 
(AvroSchemaUtil.isOptionSchema(union)) { + options.add(visit(type, branch, visitor)); + } else { + options.add(visit(type.asStructType().fields().get(index).type(), branch, visitor)); + } + index++; } } return visitor.union(type, union, options); diff --git a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java index 5a954103ae..ecdfc34c9d 100644 --- a/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java +++ b/core/src/main/java/org/apache/iceberg/avro/BuildAvroProjection.java @@ -148,13 +148,13 @@ public Schema.Field field(Schema.Field field, Supplier fieldResult) { @Override public Schema union(Schema union, Iterable options) { - Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", union); - Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); - Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); + if (AvroSchemaUtil.isOptionSchema(union)) { + Schema nonNullOriginal = AvroSchemaUtil.fromOption(union); + Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options)); - if (nonNullOriginal != nonNullResult) { - return AvroSchemaUtil.toOption(nonNullResult); + if (nonNullOriginal != nonNullResult) { + return AvroSchemaUtil.toOption(nonNullResult); + } } return union; diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java index 6176070547..828883ef32 100644 --- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java +++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java @@ -106,25 +106,27 @@ public Schema record(Schema record, List names, List fields) { @Override public Schema union(Schema union, List options) { - Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union), - "Invalid schema: non-option unions are not supported: %s", union); - - // 
only unions with null are allowed, and a null schema results in null - Schema pruned = null; - if (options.get(0) != null) { - pruned = options.get(0); - } else if (options.get(1) != null) { - pruned = options.get(1); - } + if (AvroSchemaUtil.isOptionSchema(union)) { + // case option union + Schema pruned = null; + if (options.get(0) != null) { + pruned = options.get(0); + } else if (options.get(1) != null) { + pruned = options.get(1); + } - if (pruned != null) { - if (pruned != AvroSchemaUtil.fromOption(union)) { - return AvroSchemaUtil.toOption(pruned); + if (pruned != null) { + if (pruned != AvroSchemaUtil.fromOption(union)) { + return AvroSchemaUtil.toOption(pruned); + } + return union; } + + return null; + } else { + // Complex union case return union; } - - return null; } @Override diff --git a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java index 577fa2330c..933042c449 100644 --- a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java +++ b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java @@ -92,7 +92,7 @@ public Type record(Schema record, List names, List fieldTypes) { Type fieldType = fieldTypes.get(i); int fieldId = getId(field); - if (AvroSchemaUtil.isOptionSchema(field.schema())) { + if (AvroSchemaUtil.isOptionSchema(field.schema()) || AvroSchemaUtil.isOptionalComplexUnion(field.schema())) { newFields.add(Types.NestedField.optional(fieldId, field.name(), fieldType, field.doc())); } else { newFields.add(Types.NestedField.required(fieldId, field.name(), fieldType, field.doc())); @@ -104,13 +104,26 @@ public Type record(Schema record, List names, List fieldTypes) { @Override public Type union(Schema union, List options) { - Preconditions.checkArgument(AvroSchemaUtil.isOptionSchema(union), - "Unsupported type: non-option union: %s", union); - // records, arrays, and maps will check nullability later - if (options.get(0) == null) { - return options.get(1); + if 
(AvroSchemaUtil.isOptionSchema(union)) { + // Optional simple union + // records, arrays, and maps will check nullability later + if (options.get(0) == null) { + return options.get(1); + } else { + return options.get(0); + } } else { - return options.get(0); + // Complex union + List newFields = Lists.newArrayListWithExpectedSize(options.size()); + + int tagIndex = 0; + for (Type type : options) { + if (type != null) { + newFields.add(Types.NestedField.optional(allocateId(), "tag_" + tagIndex++, type)); + } + } + + return Types.StructType.of(newFields); } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java b/core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java new file mode 100644 index 0000000000..0cc2c58ddf --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.avro; + +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.junit.Assert; +import org.junit.Test; + + +public class TestAvroComplexUnion { + + @Test + public void testRequiredComplexUnion() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + String expectedIcebergSchema = "table {\n" + + " 0: unionCol: required struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}"; + + Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); + } + + @Test + public void testOptionalComplexUnion() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + String expectedIcebergSchema = + "table {\n" + " 0: unionCol: optional struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}"; + + Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); + } + + @Test + public void testSingleComponentUnion() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + String expectedIcebergSchema = "table {\n" + " 0: unionCol: required struct<1: tag_0: optional int>\n" + "}"; + + Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); + } + + @Test + public void testOptionSchema() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("optionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + 
.endUnion() + .nullDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + String expectedIcebergSchema = "table {\n" + " 0: optionCol: optional int\n" + "}"; + + Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); + } + + @Test + public void testNullUnionSchema() { + Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("nullUnionCol") + .type() + .unionOf() + .nullType() + .endUnion() + .noDefault() + .endRecord(); + + org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); + String expectedIcebergSchema = "table {\n" + " 0: nullUnionCol: optional struct<>\n" + "}"; + + Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java index 46c594e56a..88888f436e 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java @@ -27,6 +27,7 @@ import org.apache.avro.Schema; import org.apache.avro.io.DatumReader; import org.apache.avro.io.Decoder; +import org.apache.iceberg.avro.AvroSchemaUtil; import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor; import org.apache.iceberg.avro.ValueReader; import org.apache.iceberg.avro.ValueReaders; @@ -79,7 +80,11 @@ public ValueReader record(Types.StructType expected, Schema record, List union(Type expected, Schema union, List> options) { - return ValueReaders.union(options); + if (AvroSchemaUtil.isOptionSchema(union)) { + return ValueReaders.union(options); + } else { + return SparkValueReaders.union(options); + } } @Override diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java index 0d3ce2b28d..6f9171ec8f 100644 --- 
a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -81,6 +81,10 @@ static ValueReader struct(List> readers, Types.Struc return new StructReader(readers, struct, idToConstant); } + static ValueReader union(List> readers) { + return new UnionReader(readers); + } + private static class StringReader implements ValueReader { private static final StringReader INSTANCE = new StringReader(); @@ -285,4 +289,29 @@ protected void set(InternalRow struct, int pos, Object value) { } } } + + static class UnionReader implements ValueReader { + private final ValueReader[] readers; + + private UnionReader(List> readers) { + this.readers = new ValueReader[readers.size()]; + for (int i = 0; i < this.readers.length; i += 1) { + this.readers[i] = readers.get(i); + } + } + + @Override + public InternalRow read(Decoder decoder, Object reuse) throws IOException { + InternalRow struct = new GenericInternalRow(readers.length); + int index = decoder.readIndex(); + Object value = this.readers[index].read(decoder, reuse); + + for (int i = 0; i < readers.length; i += 1) { + struct.setNullAt(i); + } + struct.update(index, value); + + return struct; + } + } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java new file mode 100644 index 0000000000..6ed4201be1 --- /dev/null +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data; + +import java.io.File; +import java.io.IOException; +import java.util.List; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.iceberg.Files; +import org.apache.iceberg.Schema; +import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.avro.AvroIterable; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + + +public class TestSparkAvroUnions { + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void writeAndValidateRequiredComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "StringType1"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + 
writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + } + } + + @Test + public void writeAndValidateOptionalComplexUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .nullType() + .and() + .intType() + .and() + .stringType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", "StringType1"); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 1); + + File testFile = temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + } + } + + @Test + public void writeAndValidateSingleComponentUnion() throws IOException { + org.apache.avro.Schema avroSchema = SchemaBuilder.record("root") + .fields() + .name("unionCol") + .type() + .unionOf() + .intType() + .endUnion() + .noDefault() + .endRecord(); + + GenericData.Record unionRecord1 = new GenericData.Record(avroSchema); + unionRecord1.put("unionCol", 1); + GenericData.Record unionRecord2 = new GenericData.Record(avroSchema); + unionRecord2.put("unionCol", 2); + + File testFile = 
temp.newFile(); + Assert.assertTrue("Delete should succeed", testFile.delete()); + + try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { + writer.create(avroSchema, testFile); + writer.append(unionRecord1); + writer.append(unionRecord2); + } + + Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema); + + List rows; + try (AvroIterable reader = Avro.read(Files.localInput(testFile)) + .createReaderFunc(SparkAvroReader::new) + .project(expectedSchema) + .build()) { + rows = Lists.newArrayList(reader); + } + } +}