diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java index 9bbcf9d770..e6515f1180 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java @@ -61,7 +61,7 @@ public static T visit(Schema schema, AvroSchemaVisitor visitor) { int idx = 0; for (Schema type : types) { if (type.getType() != Schema.Type.NULL) { - options.add(visitWithName("tag_" + idx, type, visitor)); + options.add(visitWithName("field" + idx, type, visitor)); idx += 1; } else { options.add(visit(type, visitor)); diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java index 4c8aff64be..606529685f 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaWithTypeVisitor.java @@ -80,17 +80,24 @@ private static T visitUnion(Type type, Schema union, AvroSchemaWithTypeVisit List types = union.getTypes(); List options = Lists.newArrayListWithExpectedSize(types.size()); - int index = 0; - for (Schema branch : types) { - if (branch.getType() == Schema.Type.NULL) { - options.add(visit((Type) null, branch, visitor)); - } else { - if (AvroSchemaUtil.isOptionSchema(union)) { + // simple union case + if (AvroSchemaUtil.isOptionSchema(union)) { + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); + } else { options.add(visit(type, branch, visitor)); + } + } + } else { // complex union case + int index = 1; + for (Schema branch : types) { + if (branch.getType() == Schema.Type.NULL) { + options.add(visit((Type) null, branch, visitor)); } else { options.add(visit(type.asStructType().fields().get(index).type(), branch, visitor)); + index += 1; } - index++; } } return visitor.union(type, union, options); diff --git a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java index cf8cd4ecdf..285a520753 100644 --- a/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java +++ b/core/src/main/java/org/apache/iceberg/avro/SchemaToType.java @@ -19,6 +19,7 @@ package org.apache.iceberg.avro; +import java.util.ArrayList; import java.util.List; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; @@ -116,12 +117,13 @@ public Type union(Schema union, List options) { } } else { // Complex union - List newFields = Lists.newArrayListWithExpectedSize(options.size()); + List newFields = new ArrayList<>(); + newFields.add(Types.NestedField.required(allocateId(), "tag", Types.IntegerType.get())); int tagIndex = 0; for (Type type : options) { if (type != null) { - newFields.add(Types.NestedField.optional(allocateId(), "tag_" + tagIndex++, type)); + newFields.add(Types.NestedField.optional(allocateId(), "field" + tagIndex++, type)); } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java similarity index 64% rename from core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java rename to core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java index 0cc2c58ddf..1ba6471735 100644 --- a/core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java +++ b/core/src/test/java/org/apache/iceberg/avro/TestUnionSchemaConversions.java @@ -25,7 +25,7 @@ import org.junit.Test; -public class TestAvroComplexUnion { +public class TestUnionSchemaConversions { @Test public void testRequiredComplexUnion() { @@ -43,7 +43,8 @@ public void testRequiredComplexUnion() { org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); String expectedIcebergSchema = "table {\n" + - " 0: unionCol: required struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}"; + " 0: unionCol: required struct<1: tag: required int, 2: field0: optional int, 3: field1: optional string>\n" + + "}"; Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); } @@ -65,32 +66,15 @@ public void testOptionalComplexUnion() { .endRecord(); org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - String expectedIcebergSchema = - "table {\n" + " 0: unionCol: optional struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}"; - - Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); - } - - @Test - public void testSingleComponentUnion() { - Schema avroSchema = SchemaBuilder.record("root") - .fields() - .name("unionCol") - .type() - .unionOf() - .intType() - .endUnion() - .noDefault() - .endRecord(); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - String expectedIcebergSchema = "table {\n" + " 0: unionCol: required struct<1: tag_0: optional int>\n" + "}"; + String expectedIcebergSchema = "table {\n" + + " 0: unionCol: optional struct<1: tag: required int, 2: field0: optional int, 3: field1: optional string>\n" + + "}"; Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); } @Test - public void testOptionSchema() { + public void testSimpleUnionSchema() { Schema avroSchema = SchemaBuilder.record("root") .fields() .name("optionCol") @@ -108,22 +92,4 @@ public void testOptionSchema() { Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); } - - @Test - public void testNullUnionSchema() { - Schema avroSchema = SchemaBuilder.record("root") - .fields() - .name("nullUnionCol") - .type() - .unionOf() - .nullType() - .endUnion() - .noDefault() - .endRecord(); - - org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema); - String expectedIcebergSchema = "table {\n" + " 0: nullUnionCol: optional struct<>\n" + "}"; - - Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString()); - } } diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java index c222cd2530..044b4fb5eb 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkValueReaders.java @@ -292,7 +292,7 @@ protected void set(InternalRow struct, int pos, Object value) { } } - static class UnionReader implements ValueReader { + private static class UnionReader implements ValueReader { private final Schema schema; private final ValueReader[] readers; @@ -316,20 +316,30 @@ public InternalRow read(Decoder decoder, Object reuse) throws IOException { break; } } - InternalRow struct = new GenericInternalRow(nullIndex >= 0 ? alts.size() - 1 : alts.size()); + + int index = decoder.readIndex(); + if (index == nullIndex) { + // if it is a null data, directly return null as the whole union result + return null; + } + + // otherwise, we need to return an InternalRow as a struct data + InternalRow struct = new GenericInternalRow(nullIndex >= 0 ? alts.size() : alts.size() + 1); for (int i = 0; i < struct.numFields(); i += 1) { struct.setNullAt(i); } - int index = decoder.readIndex(); - Object value = this.readers[index].read(decoder, reuse); + Object value = readers[index].read(decoder, reuse); if (nullIndex < 0) { - struct.update(index, value); + struct.update(index + 1, value); + struct.setInt(0, index); } else if (index < nullIndex) { + struct.update(index + 1, value); + struct.setInt(0, index); + } else { struct.update(index, value); - } else if (index > nullIndex) { - struct.update(index - 1, value); + struct.setInt(0, index - 1); } return struct; diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java index 8f38ff93bc..636a4e6baf 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkAvroUnions.java @@ -65,8 +65,6 @@ public void writeAndValidateRequiredComplexUnion() throws IOException { unionRecord2.put("unionCol", 1); File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { writer.create(avroSchema, testFile); writer.append(unionRecord1); @@ -82,13 +80,15 @@ public void writeAndValidateRequiredComplexUnion() throws IOException { .build()) { rows = Lists.newArrayList(reader); - Assert.assertEquals(2, rows.get(0).getStruct(0, 2).numFields()); - Assert.assertTrue(rows.get(0).getStruct(0, 2).isNullAt(0)); - Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1)); + Assert.assertEquals(3, rows.get(0).getStruct(0, 3).numFields()); + Assert.assertEquals(1, rows.get(0).getStruct(0, 3).getInt(0)); + Assert.assertTrue(rows.get(0).getStruct(0, 3).isNullAt(1)); + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); - Assert.assertEquals(2, rows.get(1).getStruct(0, 2).numFields()); - Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0)); - Assert.assertTrue(rows.get(1).getStruct(0, 2).isNullAt(1)); + Assert.assertEquals(3, rows.get(1).getStruct(0, 3).numFields()); + Assert.assertEquals(0, rows.get(1).getStruct(0, 3).getInt(0)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(1).getStruct(0, 3).isNullAt(2)); } } @@ -116,8 +116,6 @@ public void writeAndValidateOptionalComplexUnion() throws IOException { unionRecord3.put("unionCol", null); File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { writer.create(avroSchema, testFile); writer.append(unionRecord1); @@ -134,10 +132,9 @@ public void writeAndValidateOptionalComplexUnion() throws IOException { .build()) { rows = Lists.newArrayList(reader); - Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1)); - Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0)); - Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(0)); - Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(1)); + Assert.assertEquals("foo", rows.get(0).getStruct(0, 3).getString(2)); + Assert.assertEquals(1, rows.get(1).getStruct(0, 3).getInt(1)); + Assert.assertTrue(rows.get(2).isNullAt(0)); } } @@ -161,8 +158,6 @@ public void writeAndValidateSingleTypeUnion() throws IOException { unionRecord2.put("unionCol", 1); File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { writer.create(avroSchema, testFile); writer.append(unionRecord1); @@ -207,8 +202,6 @@ public void testDeeplyNestedUnionSchema1() throws IOException { unionRecord2.put("col1", Arrays.asList(2, "bar")); File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { writer.create(avroSchema, testFile); writer.append(unionRecord1); @@ -225,7 +218,7 @@ public void testDeeplyNestedUnionSchema1() throws IOException { rows = Lists.newArrayList(reader); // making sure it reads the correctly nested structured data, based on the transformation from union to struct - Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 2).getString(1)); + Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 3).getString(2)); } } @@ -265,8 +258,6 @@ public void testDeeplyNestedUnionSchema2() throws IOException { outer.put("col1", Arrays.asList(inner)); File testFile = temp.newFile(); - Assert.assertTrue("Delete should succeed", testFile.delete()); - try (DataFileWriter writer = new DataFileWriter<>(new GenericDatumWriter<>())) { writer.create(avroSchema, testFile); writer.append(outer); @@ -281,7 +272,7 @@ public void testDeeplyNestedUnionSchema2() throws IOException { rows = Lists.newArrayList(reader); // making sure it reads the correctly nested structured data, based on the transformation from union to struct - Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 2).getStruct(0, 1).getInt(0)); + Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 3).getStruct(1, 1).getInt(0)); } } }