diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java
index 281d45b51c..9bbcf9d770 100644
--- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java
+++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java
@@ -52,8 +52,21 @@ public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
       case UNION:
         List<Schema> types = schema.getTypes();
         List<T> options = Lists.newArrayListWithExpectedSize(types.size());
-        for (Schema type : types) {
-          options.add(visit(type, visitor));
+        if (AvroSchemaUtil.isOptionSchema(schema)) {
+          for (Schema type : types) {
+            options.add(visit(type, visitor));
+          }
+        } else {
+          // complex union case
+          int idx = 0;
+          for (Schema type : types) {
+            if (type.getType() != Schema.Type.NULL) {
+              options.add(visitWithName("tag_" + idx, type, visitor));
+              idx += 1;
+            } else {
+              options.add(visit(type, visitor));
+            }
+          }
         }
         return visitor.union(schema, options);
diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
index 097160eba3..f524f8b68b 100644
--- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
+++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
@@ -125,10 +125,11 @@ public Schema union(Schema union, List<Schema> options) {
       return null;
     } else {
       // Complex union case
-      return union;
+      return copyUnion(union, options);
     }
   }
 
+  @Override
   @SuppressWarnings("checkstyle:CyclomaticComplexity")
   public Schema array(Schema array, Schema element) {
@@ -297,4 +298,19 @@ private static Schema.Field copyField(Schema.Field field, Schema newSchema, Integer fieldId) {
   private static boolean isOptionSchemaWithNonNullFirstOption(Schema schema) {
     return AvroSchemaUtil.isOptionSchema(schema) &&
         schema.getTypes().get(0).getType() != Schema.Type.NULL;
   }
+
+  // For primitive types the visit result will be null, so we reuse the primitive type from the
+  // original schema; for nested types we use the visit result, because it carries the content
+  // produced by the preceding recursive calls.
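+  // For example, pruning union(null, int, record) where only fields of the record are projected
+  // yields visit results [null, null, prunedRecord]; copyUnion then rebuilds union(null, int, prunedRecord).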
+  private static Schema copyUnion(Schema record, List<Schema> visitResults) {
+    List<Schema> alts = Lists.newArrayListWithExpectedSize(visitResults.size());
+    for (int i = 0; i < visitResults.size(); i++) {
+      if (visitResults.get(i) == null) {
+        alts.add(record.getTypes().get(i));
+      } else {
+        alts.add(visitResults.get(i));
+      }
+    }
+    return Schema.createUnion(alts);
+  }
 }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java
index 88888f436e..f467650261 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkAvroReader.java
@@ -83,7 +83,7 @@ public ValueReader<?> union(Type expected, Schema union, List<ValueReader<?>> options) {
     if (AvroSchemaUtil.isOptionSchema(union)) {
       return ValueReaders.union(options);
     } else {
-      return SparkValueReaders.union(options);
+      return SparkValueReaders.union(union, options);
     }
   }
diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java
index 6f9171ec8f..c222cd2530 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java
@@ -27,6 +27,8 @@
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import org.apache.avro.Schema;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.util.Utf8;
 import org.apache.iceberg.avro.ValueReader;
@@ -81,8 +83,8 @@ static ValueReader<InternalRow> struct(List<ValueReader<?>> readers, Types.StructType struct,
     return new StructReader(readers, struct, idToConstant);
   }
 
-  static ValueReader<InternalRow> union(List<ValueReader<?>> readers) {
-    return new UnionReader(readers);
+  static ValueReader<InternalRow> union(Schema schema, List<ValueReader<?>> readers) {
+    return new UnionReader(schema, readers);
   }
 
   private static class StringReader implements ValueReader<UTF8String> {
@@ -291,9 +293,11 @@ protected void set(InternalRow struct, int pos, Object value) {
   }
 
   static class UnionReader implements ValueReader<InternalRow> {
+    private final Schema schema;
     private final ValueReader<?>[] readers;
 
-    private UnionReader(List<ValueReader<?>> readers) {
+    private UnionReader(Schema schema, List<ValueReader<?>> readers) {
+      this.schema = schema;
       this.readers = new ValueReader[readers.size()];
       for (int i = 0; i < this.readers.length; i += 1) {
         this.readers[i] = readers.get(i);
@@ -302,14 +306,31 @@ private UnionReader(List<ValueReader<?>> readers) {
 
     @Override
     public InternalRow read(Decoder decoder, Object reuse) throws IOException {
-      InternalRow struct = new GenericInternalRow(readers.length);
+      // first we need to filter out the NULL alternative if it exists in the union schema
+      int nullIndex = -1;
+      List<Schema> alts = schema.getTypes();
+      for (int i = 0; i < alts.size(); i++) {
+        Schema alt = alts.get(i);
+        if (Objects.equals(alt.getType(), Schema.Type.NULL)) {
+          nullIndex = i;
+          break;
+        }
+      }
+      InternalRow struct = new GenericInternalRow(nullIndex >= 0 ? alts.size() - 1 : alts.size());
+      for (int i = 0; i < struct.numFields(); i += 1) {
+        struct.setNullAt(i);
+      }
+
       int index = decoder.readIndex();
       Object value = this.readers[index].read(decoder, reuse);
 
-      for (int i = 0; i < readers.length; i += 1) {
-        struct.setNullAt(i);
+      if (nullIndex < 0) {
+        struct.update(index, value);
+      } else if (index < nullIndex) {
+        struct.update(index, value);
+      } else if (index > nullIndex) {
+        struct.update(index - 1, value);
       }
-      struct.update(index, value);
 
       return struct;
     }
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java
index 6ed4201be1..8f38ff93bc 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java
@@ -21,6 +21,7 @@
 
 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.avro.SchemaBuilder;
 import org.apache.avro.file.DataFileWriter;
@@ -59,7 +60,7 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .endRecord();
 
     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);
 
@@ -80,6 +81,14 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(2, rows.get(0).getStruct(0, 2).numFields());
+      Assert.assertTrue(rows.get(0).getStruct(0, 2).isNullAt(0));
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+
+      Assert.assertEquals(2, rows.get(1).getStruct(0, 2).numFields());
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(1).getStruct(0, 2).isNullAt(1));
     }
   }
 
@@ -96,13 +105,15 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
         .and()
         .stringType()
         .endUnion()
-        .noDefault()
+        .nullDefault()
         .endRecord();
 
     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);
+    GenericData.Record unionRecord3 = new GenericData.Record(avroSchema);
+    unionRecord3.put("unionCol", null);
 
     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -111,6 +122,7 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
       writer.create(avroSchema, testFile);
       writer.append(unionRecord1);
      writer.append(unionRecord2);
+      writer.append(unionRecord3);
     }
 
     Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
@@ -121,25 +133,78 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(1));
     }
   }
 
   @Test
-  public void writeAndValidateSingleComponentUnion() throws IOException {
+  public void writeAndValidateSingleTypeUnion() throws IOException {
     org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
         .fields()
         .name("unionCol")
         .type()
         .unionOf()
+        .nullType()
+        .and()
         .intType()
         .endUnion()
+        .nullDefault()
+        .endRecord();
+
+    GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
+    unionRecord1.put("unionCol", 0);
+    GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
+    unionRecord2.put("unionCol", 1);
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(unionRecord1);
+      writer.append(unionRecord2);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(0, rows.get(0).getInt(0));
+      Assert.assertEquals(1, rows.get(1).getInt(0));
+    }
+  }
+
+  @Test
+  public void testDeeplyNestedUnionSchema1() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .nullType()
+        .and()
+        .intType()
+        .and()
+        .stringType()
+        .endUnion()
         .noDefault()
         .endRecord();
 
     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", 1);
+    unionRecord1.put("col1", Arrays.asList("foo", 1));
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
-    unionRecord2.put("unionCol", 2);
+    unionRecord2.put("col1", Arrays.asList(2, "bar"));
 
     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -158,6 +223,65 @@ public void writeAndValidateSingleComponentUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      // make sure it reads the nested structured data correctly, based on the union-to-struct transformation
+      Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 2).getString(1));
     }
   }
+
+  @Test
+  public void testDeeplyNestedUnionSchema2() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .record("r1")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .and()
+        .record("r2")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .endUnion()
+        .noDefault()
+        .endRecord();
+
+    GenericData.Record outer = new GenericData.Record(avroSchema);
+    GenericData.Record inner = new GenericData.Record(avroSchema.getFields().get(0).schema()
+        .getElementType().getTypes().get(0));
+
+    inner.put("id", 1);
+    outer.put("col1", Arrays.asList(inner));
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(outer);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      // make sure it reads the nested structured data correctly, based on the union-to-struct transformation
+      Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 2).getStruct(0, 1).getInt(0));
     }
   }
 }
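Note (reviewer sketch, not part of the patch): the index arithmetic in UnionReader.read can be summarized by the helper below. This is a minimal, illustrative sketch; the class and method names (UnionBranchMapping, structFieldPosition) are made up for this note and do not exist in the codebase.

    // Maps an Avro union branch index to the field position in the Spark struct
    // produced for a complex union: the single NULL branch (if any) is dropped,
    // and every branch after it shifts left by one. Not meaningful when the value
    // read is the null branch itself (all struct fields stay null in that case).
    class UnionBranchMapping {
      static int structFieldPosition(int branchIndex, int nullIndex) {
        if (nullIndex < 0 || branchIndex < nullIndex) {
          return branchIndex;   // no null branch, or branch precedes it
        }
        return branchIndex - 1; // branch follows the null branch
      }
    }

With the union (null, int, string) used in the tests, nullIndex is 0, so the int branch (index 1) maps to struct field 0 ("tag_0") and the string branch (index 2) maps to field 1 ("tag_1"), which is exactly what the assertions above check.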