Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions core/src/main/java/org/apache/iceberg/avro/AvroSchemaUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,27 @@ public static boolean isOptionSchema(Schema schema) {
return false;
}

/**
* This method decides whether a schema is of type union and is complex union and is optional
*
* Complex union: the number of options in union not equals to 2
* Optional: null is present in union
*
* @param schema input schema
* @return true if schema is complex union and it is optional
*/
public static boolean isOptionalComplexUnion(Schema schema) {
if (schema.getType() == UNION && schema.getTypes().size() != 2) {
for (Schema type : schema.getTypes()) {
if (type.getType() == Schema.Type.NULL) {
return true;
}
}
}

return false;
}

public static Schema toOption(Schema schema) {
if (schema.getType() == UNION) {
Preconditions.checkArgument(isOptionSchema(schema),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,18 @@ private static <T> T visitRecord(Types.StructType struct, Schema record, AvroSch
private static <T> T visitUnion(Type type, Schema union, AvroSchemaWithTypeVisitor<T> visitor) {
List<Schema> types = union.getTypes();
List<T> options = Lists.newArrayListWithExpectedSize(types.size());

int index = 0;
for (Schema branch : types) {
if (branch.getType() == Schema.Type.NULL) {
options.add(visit((Type) null, branch, visitor));
} else {
options.add(visit(type, branch, visitor));
if (AvroSchemaUtil.isOptionSchema(union)) {
options.add(visit(type, branch, visitor));
} else {
options.add(visit(type.asStructType().fields().get(index).type(), branch, visitor));
}
index++;
}
}
return visitor.union(type, union, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ public Schema.Field field(Schema.Field field, Supplier<Schema> fieldResult) {

@Override
public Schema union(Schema union, Iterable<Schema> options) {
Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union),
"Invalid schema: non-option unions are not supported: %s", union);
Schema nonNullOriginal = AvroSchemaUtil.fromOption(union);
Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options));
if (AvroSchemaUtil.isOptionSchema(union)) {
Schema nonNullOriginal = AvroSchemaUtil.fromOption(union);
Schema nonNullResult = AvroSchemaUtil.fromOptions(Lists.newArrayList(options));

if (nonNullOriginal != nonNullResult) {
return AvroSchemaUtil.toOption(nonNullResult);
if (nonNullOriginal != nonNullResult) {
return AvroSchemaUtil.toOption(nonNullResult);
}
}

return union;
Expand Down
32 changes: 17 additions & 15 deletions core/src/main/java/org/apache/iceberg/avro/PruneColumns.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,25 +106,27 @@ public Schema record(Schema record, List<String> names, List<Schema> fields) {

@Override
public Schema union(Schema union, List<Schema> options) {
Preconditions.checkState(AvroSchemaUtil.isOptionSchema(union),
"Invalid schema: non-option unions are not supported: %s", union);

// only unions with null are allowed, and a null schema results in null
Schema pruned = null;
if (options.get(0) != null) {
pruned = options.get(0);
} else if (options.get(1) != null) {
pruned = options.get(1);
}
if (AvroSchemaUtil.isOptionSchema(union)) {
// case option union
Schema pruned = null;
if (options.get(0) != null) {
pruned = options.get(0);
} else if (options.get(1) != null) {
pruned = options.get(1);
}

if (pruned != null) {
if (pruned != AvroSchemaUtil.fromOption(union)) {
return AvroSchemaUtil.toOption(pruned);
if (pruned != null) {
if (pruned != AvroSchemaUtil.fromOption(union)) {
return AvroSchemaUtil.toOption(pruned);
}
return union;
}

return null;
} else {
// Complex union case
return union;
}

return null;
}

@Override
Expand Down
27 changes: 20 additions & 7 deletions core/src/main/java/org/apache/iceberg/avro/SchemaToType.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public Type record(Schema record, List<String> names, List<Type> fieldTypes) {
Type fieldType = fieldTypes.get(i);
int fieldId = getId(field);

if (AvroSchemaUtil.isOptionSchema(field.schema())) {
if (AvroSchemaUtil.isOptionSchema(field.schema()) || AvroSchemaUtil.isOptionalComplexUnion(field.schema())) {
newFields.add(Types.NestedField.optional(fieldId, field.name(), fieldType, field.doc()));
} else {
newFields.add(Types.NestedField.required(fieldId, field.name(), fieldType, field.doc()));
Expand All @@ -104,13 +104,26 @@ public Type record(Schema record, List<String> names, List<Type> fieldTypes) {

@Override
public Type union(Schema union, List<Type> options) {
Preconditions.checkArgument(AvroSchemaUtil.isOptionSchema(union),
"Unsupported type: non-option union: %s", union);
// records, arrays, and maps will check nullability later
if (options.get(0) == null) {
return options.get(1);
if (AvroSchemaUtil.isOptionSchema(union)) {
// Optional simple union
// records, arrays, and maps will check nullability later
if (options.get(0) == null) {
return options.get(1);
} else {
return options.get(0);
}
} else {
return options.get(0);
// Complex union
List<Types.NestedField> newFields = Lists.newArrayListWithExpectedSize(options.size());

int tagIndex = 0;
for (Type type : options) {
if (type != null) {
newFields.add(Types.NestedField.optional(allocateId(), "tag_" + tagIndex++, type));
}
}

return Types.StructType.of(newFields);
}
}

Expand Down
129 changes: 129 additions & 0 deletions core/src/test/java/org/apache/iceberg/avro/TestAvroComplexUnion.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iceberg.avro;

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.junit.Assert;
import org.junit.Test;


public class TestAvroComplexUnion {

@Test
public void testRequiredComplexUnion() {
Schema avroSchema = SchemaBuilder.record("root")
.fields()
.name("unionCol")
.type()
.unionOf()
.intType()
.and()
.stringType()
.endUnion()
.noDefault()
.endRecord();

org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema);
String expectedIcebergSchema = "table {\n" +
" 0: unionCol: required struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}";

Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString());
}

@Test
public void testOptionalComplexUnion() {
Schema avroSchema = SchemaBuilder.record("root")
.fields()
.name("unionCol")
.type()
.unionOf()
.nullType()
.and()
.intType()
.and()
.stringType()
.endUnion()
.noDefault()
.endRecord();

org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema);
String expectedIcebergSchema =
"table {\n" + " 0: unionCol: optional struct<1: tag_0: optional int, 2: tag_1: optional string>\n" + "}";

Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString());
}

@Test
public void testSingleComponentUnion() {
Schema avroSchema = SchemaBuilder.record("root")
.fields()
.name("unionCol")
.type()
.unionOf()
.intType()
.endUnion()
.noDefault()
.endRecord();

org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema);
String expectedIcebergSchema = "table {\n" + " 0: unionCol: required struct<1: tag_0: optional int>\n" + "}";

Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString());
}

@Test
public void testOptionSchema() {
Schema avroSchema = SchemaBuilder.record("root")
.fields()
.name("optionCol")
.type()
.unionOf()
.nullType()
.and()
.intType()
.endUnion()
.nullDefault()
.endRecord();

org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema);
String expectedIcebergSchema = "table {\n" + " 0: optionCol: optional int\n" + "}";

Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString());
}

@Test
public void testNullUnionSchema() {
Schema avroSchema = SchemaBuilder.record("root")
.fields()
.name("nullUnionCol")
.type()
.unionOf()
.nullType()
.endUnion()
.noDefault()
.endRecord();

org.apache.iceberg.Schema icebergSchema = AvroSchemaUtil.toIceberg(avroSchema);
String expectedIcebergSchema = "table {\n" + " 0: nullUnionCol: optional struct<>\n" + "}";

Assert.assertEquals(expectedIcebergSchema, icebergSchema.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.avro.Schema;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.Decoder;
import org.apache.iceberg.avro.AvroSchemaUtil;
import org.apache.iceberg.avro.AvroSchemaWithTypeVisitor;
import org.apache.iceberg.avro.ValueReader;
import org.apache.iceberg.avro.ValueReaders;
Expand Down Expand Up @@ -79,7 +80,11 @@ public ValueReader<?> record(Types.StructType expected, Schema record, List<Stri

@Override
public ValueReader<?> union(Type expected, Schema union, List<ValueReader<?>> options) {
return ValueReaders.union(options);
if (AvroSchemaUtil.isOptionSchema(union)) {
return ValueReaders.union(options);
} else {
return SparkValueReaders.union(options);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ static ValueReader<InternalRow> struct(List<ValueReader<?>> readers, Types.Struc
return new StructReader(readers, struct, idToConstant);
}

static ValueReader<InternalRow> union(List<ValueReader<?>> readers) {
return new UnionReader(readers);
}

private static class StringReader implements ValueReader<UTF8String> {
private static final StringReader INSTANCE = new StringReader();

Expand Down Expand Up @@ -285,4 +289,29 @@ protected void set(InternalRow struct, int pos, Object value) {
}
}
}

static class UnionReader implements ValueReader<InternalRow> {
private final ValueReader[] readers;

private UnionReader(List<ValueReader<?>> readers) {
this.readers = new ValueReader[readers.size()];
for (int i = 0; i < this.readers.length; i += 1) {
this.readers[i] = readers.get(i);
}
}

@Override
public InternalRow read(Decoder decoder, Object reuse) throws IOException {
InternalRow struct = new GenericInternalRow(readers.length);
int index = decoder.readIndex();
Object value = this.readers[index].read(decoder, reuse);

for (int i = 0; i < readers.length; i += 1) {
struct.setNullAt(i);
}
struct.update(index, value);

return struct;
}
}
}
Loading