17 changes: 15 additions & 2 deletions core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java
@@ -52,8 +52,21 @@ public static <T> T visit(Schema schema, AvroSchemaVisitor<T> visitor) {
       case UNION:
         List<Schema> types = schema.getTypes();
         List<T> options = Lists.newArrayListWithExpectedSize(types.size());
-        for (Schema type : types) {
-          options.add(visit(type, visitor));
+        if (AvroSchemaUtil.isOptionSchema(schema)) {
+          for (Schema type : types) {
+            options.add(visit(type, visitor));
+          }
+        } else {
+          // complex union case
+          int idx = 0;
+          for (Schema type : types) {
+            if (type.getType() != Schema.Type.NULL) {
+              options.add(visitWithName("tag_" + idx, type, visitor));
+              idx += 1;
+            } else {
+              options.add(visit(type, visitor));
+            }
+          }
         }
         return visitor.union(schema, options);

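Note on the visitor change: for a complex union, each non-null branch is visited through `visitWithName` with a synthetic `tag_N` name, while a null branch keeps the plain `visit` path. A hedged sketch (not code from this PR) of what that means for a three-branch union:

```java
import java.util.Arrays;
import org.apache.avro.Schema;

// A complex union: more than one non-null branch.
Schema union = Schema.createUnion(Arrays.asList(
    Schema.create(Schema.Type.NULL),
    Schema.create(Schema.Type.INT),
    Schema.create(Schema.Type.STRING)));

// With the change above, the int branch is visited as "tag_0" and the string
// branch as "tag_1"; the null branch only makes the result optional. The
// union can then surface downstream as a struct:
//   struct<tag_0: int, tag_1: string>
```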
18 changes: 17 additions & 1 deletion core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
@@ -125,10 +125,11 @@ public Schema union(Schema union, List<Schema> options) {
       return null;
     } else {
       // Complex union case
-      return union;
+      return copyUnion(union, options);
     }
   }

+
   @Override
   @SuppressWarnings("checkstyle:CyclomaticComplexity")
   public Schema array(Schema array, Schema element) {
@@ -297,4 +298,19 @@ private static Schema.Field copyField(Schema.Field field, Schema newSchema, Inte
   private static boolean isOptionSchemaWithNonNullFirstOption(Schema schema) {
     return AvroSchemaUtil.isOptionSchema(schema) && schema.getTypes().get(0).getType() != Schema.Type.NULL;
   }
+
+  // For primitive types the visit result is null, so we reuse the primitive type from the original
+  // schema; for nested types we use the visit result, since it carries the content rebuilt by the
+  // earlier recursive calls.
+  private static Schema copyUnion(Schema record, List<Schema> visitResults) {
+    List<Schema> alts = Lists.newArrayListWithExpectedSize(visitResults.size());
+    for (int i = 0; i < visitResults.size(); i++) {
+      if (visitResults.get(i) == null) {
+        alts.add(record.getTypes().get(i));
+      } else {
+        alts.add(visitResults.get(i));
+      }
+    }
+    return Schema.createUnion(alts);
+  }
 }
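The `copyUnion` helper merges the visitor's results back into a union. A small sketch of its contract under assumed inputs (the record branch here is hypothetical):

```java
import java.util.Arrays;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;

// Hypothetical union with two primitive branches and one record branch.
Schema nested = SchemaBuilder.record("r").fields()
    .name("id").type().intType().noDefault()
    .endRecord();
Schema union = Schema.createUnion(Arrays.asList(
    Schema.create(Schema.Type.NULL),
    Schema.create(Schema.Type.INT),
    nested));

// PruneColumns returns null for primitive branches (nothing was rebuilt) and
// a rebuilt schema for nested branches, so the visit results look like:
List<Schema> visitResults = Arrays.asList(null, null, nested);

// copyUnion(union, visitResults) then keeps branches 0 and 1 from the
// original union and takes branch 2 from the visit results.
```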
@@ -83,7 +83,7 @@ public ValueReader<?> union(Type expected, Schema union, List<ValueReader<?>> op
     if (AvroSchemaUtil.isOptionSchema(union)) {
       return ValueReaders.union(options);
     } else {
-      return SparkValueReaders.union(options);
+      return SparkValueReaders.union(union, options);
     }
   }

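On the dispatch above: `AvroSchemaUtil.isOptionSchema` is true only for the two-branch `["null", T]` shape, which keeps the existing `ValueReaders.union` path; every other union is treated as complex and now goes through the schema-aware `SparkValueReaders.union`. A hedged illustration:

```java
// Assumed routing for a few union shapes under the branch above:
//   ["null", "int"]           -> option schema -> ValueReaders.union(options)
//   ["null", "int", "string"] -> complex union -> SparkValueReaders.union(union, options)
//   ["int", "string"]         -> complex union -> SparkValueReaders.union(union, options)
```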
@@ -27,6 +27,8 @@
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import org.apache.avro.Schema;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.util.Utf8;
 import org.apache.iceberg.avro.ValueReader;
@@ -81,8 +83,8 @@ static ValueReader<InternalRow> struct(List<ValueReader<?>> readers, Types.Struc
     return new StructReader(readers, struct, idToConstant);
   }

-  static ValueReader<InternalRow> union(List<ValueReader<?>> readers) {
-    return new UnionReader(readers);
+  static ValueReader<InternalRow> union(Schema schema, List<ValueReader<?>> readers) {
+    return new UnionReader(schema, readers);
   }

   private static class StringReader implements ValueReader<UTF8String> {
@@ -291,9 +293,11 @@ protected void set(InternalRow struct, int pos, Object value) {
   }

   static class UnionReader implements ValueReader<InternalRow> {
+    private final Schema schema;
     private final ValueReader[] readers;

-    private UnionReader(List<ValueReader<?>> readers) {
+    private UnionReader(Schema schema, List<ValueReader<?>> readers) {
+      this.schema = schema;
       this.readers = new ValueReader[readers.size()];
       for (int i = 0; i < this.readers.length; i += 1) {
         this.readers[i] = readers.get(i);
@@ -302,14 +306,31 @@ private UnionReader(List<ValueReader<?>> readers) {

   @Override
   public InternalRow read(Decoder decoder, Object reuse) throws IOException {
-    InternalRow struct = new GenericInternalRow(readers.length);
+    // first we need to filter out the NULL alternative if it exists in the union schema
+    int nullIndex = -1;
+    List<Schema> alts = schema.getTypes();
+    for (int i = 0; i < alts.size(); i++) {
+      Schema alt = alts.get(i);
+      if (Objects.equals(alt.getType(), Schema.Type.NULL)) {
+        nullIndex = i;
+        break;
+      }
+    }
+    InternalRow struct = new GenericInternalRow(nullIndex >= 0 ? alts.size() - 1 : alts.size());
+    for (int i = 0; i < struct.numFields(); i += 1) {
+      struct.setNullAt(i);
+    }

     int index = decoder.readIndex();
     Object value = this.readers[index].read(decoder, reuse);

-    for (int i = 0; i < readers.length; i += 1) {
-      struct.setNullAt(i);
+    if (nullIndex < 0) {
+      struct.update(index, value);
+    } else if (index < nullIndex) {
+      struct.update(index, value);
+    } else if (index > nullIndex) {
+      struct.update(index - 1, value);
     }
-    struct.update(index, value);

     return struct;
   }
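The index arithmetic in `UnionReader.read` maps a decoded union branch to a struct field position by skipping the null branch. A minimal standalone sketch (my naming, not PR code):

```java
// nullIndex: position of the NULL branch in the union's type list, or -1 if absent.
static int fieldPosition(int branchIndex, int nullIndex) {
  if (nullIndex < 0 || branchIndex < nullIndex) {
    return branchIndex;        // no null branch, or this branch sits before it
  } else if (branchIndex > nullIndex) {
    return branchIndex - 1;    // shift left to skip the dropped null branch
  }
  return -1;                   // the null branch itself: leave every field null
}

// For a union ["int", "null", "string"] (nullIndex == 1):
//   fieldPosition(0, 1) == 0, fieldPosition(2, 1) == 1, fieldPosition(1, 1) == -1,
// so the struct has two fields: <tag_0: int, tag_1: string>.
```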
@@ -21,6 +21,7 @@

 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import org.apache.avro.SchemaBuilder;
 import org.apache.avro.file.DataFileWriter;
@@ -59,7 +60,7 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .endRecord();

     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);

@@ -80,6 +81,14 @@ public void writeAndValidateRequiredComplexUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(2, rows.get(0).getStruct(0, 2).numFields());
+      Assert.assertTrue(rows.get(0).getStruct(0, 2).isNullAt(0));
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+
+      Assert.assertEquals(2, rows.get(1).getStruct(0, 2).numFields());
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(1).getStruct(0, 2).isNullAt(1));
     }
   }

@@ -96,13 +105,15 @@ public void writeAndValidateOptionalComplexUnion() throws IOException {
         .and()
         .stringType()
         .endUnion()
-        .noDefault()
+        .nullDefault()
         .endRecord();

     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", "StringType1");
+    unionRecord1.put("unionCol", "foo");
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
     unionRecord2.put("unionCol", 1);
+    GenericData.Record unionRecord3 = new GenericData.Record(avroSchema);
+    unionRecord3.put("unionCol", null);

     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -111,6 +122,7 @@
       writer.create(avroSchema, testFile);
       writer.append(unionRecord1);
       writer.append(unionRecord2);
+      writer.append(unionRecord3);
     }

     Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
@@ -121,25 +133,78 @@
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals("foo", rows.get(0).getStruct(0, 2).getString(1));
+      Assert.assertEquals(1, rows.get(1).getStruct(0, 2).getInt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(0));
+      Assert.assertTrue(rows.get(2).getStruct(0, 2).isNullAt(1));
     }
   }

   @Test
-  public void writeAndValidateSingleComponentUnion() throws IOException {
+  public void writeAndValidateSingleTypeUnion() throws IOException {
     org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
         .fields()
         .name("unionCol")
         .type()
         .unionOf()
         .nullType()
         .and()
         .intType()
         .endUnion()
         .nullDefault()
         .endRecord();
+
+    GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
+    unionRecord1.put("unionCol", 0);
+    GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
+    unionRecord2.put("unionCol", 1);
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(unionRecord1);
+      writer.append(unionRecord2);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      Assert.assertEquals(0, rows.get(0).getInt(0));
+      Assert.assertEquals(1, rows.get(1).getInt(0));
+    }
+  }
+
+  @Test
+  public void testDeeplyNestedUnionSchema1() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .nullType()
+        .and()
+        .intType()
+        .and()
+        .stringType()
+        .endUnion()
+        .noDefault()
+        .endRecord();
+
     GenericData.Record unionRecord1 = new GenericData.Record(avroSchema);
-    unionRecord1.put("unionCol", 1);
+    unionRecord1.put("col1", Arrays.asList("foo", 1));
     GenericData.Record unionRecord2 = new GenericData.Record(avroSchema);
-    unionRecord2.put("unionCol", 2);
+    unionRecord2.put("col1", Arrays.asList(2, "bar"));
+
     File testFile = temp.newFile();
     Assert.assertTrue("Delete should succeed", testFile.delete());
@@ -158,6 +223,65 @@ public void writeAndValidateSingleComponentUnion() throws IOException {
         .project(expectedSchema)
         .build()) {
       rows = Lists.newArrayList(reader);
+
+      // make sure the nested data is read correctly, given the transformation from union to struct
+      Assert.assertEquals("foo", rows.get(0).getArray(0).getStruct(0, 2).getString(1));
     }
   }
+
+  @Test
+  public void testDeeplyNestedUnionSchema2() throws IOException {
+    org.apache.avro.Schema avroSchema = SchemaBuilder.record("root")
+        .fields()
+        .name("col1")
+        .type()
+        .array()
+        .items()
+        .unionOf()
+        .record("r1")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .and()
+        .record("r2")
+        .fields()
+        .name("id")
+        .type()
+        .intType()
+        .noDefault()
+        .endRecord()
+        .endUnion()
+        .noDefault()
+        .endRecord();
+
+    GenericData.Record outer = new GenericData.Record(avroSchema);
+    GenericData.Record inner = new GenericData.Record(avroSchema.getFields().get(0).schema()
+        .getElementType().getTypes().get(0));
+
+    inner.put("id", 1);
+    outer.put("col1", Arrays.asList(inner));
+
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (DataFileWriter<GenericData.Record> writer = new DataFileWriter<>(new GenericDatumWriter<>())) {
+      writer.create(avroSchema, testFile);
+      writer.append(outer);
+    }
+
+    Schema expectedSchema = AvroSchemaUtil.toIceberg(avroSchema);
+    List<InternalRow> rows;
+    try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
+        .createReaderFunc(SparkAvroReader::new)
+        .project(expectedSchema)
+        .build()) {
+      rows = Lists.newArrayList(reader);
+
+      // make sure the nested data is read correctly, given the transformation from union to struct
+      Assert.assertEquals(1, rows.get(0).getArray(0).getStruct(0, 2).getStruct(0, 1).getInt(0));
+    }
+  }
 }
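For reference, a hedged reading of what `testDeeplyNestedUnionSchema2` asserts, with the struct shape assumed from the `tag_N` convention introduced in `AvroSchemaVisitor`:

```java
// Assumed projected shape: col1: list<struct<tag_0: struct<id: int>, tag_1: struct<id: int>>>
// The assertion drills in as:
//   rows.get(0)          -> the single row
//     .getArray(0)       -> col1
//     .getStruct(0, 2)   -> first element: the union-as-struct (2 fields)
//     .getStruct(0, 1)   -> tag_0: the "r1" branch (1 field)
//     .getInt(0)         -> id == 1
```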