diff --git a/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java b/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java index ad4c80fe43..db14c1f1c0 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java +++ b/orc/src/main/java/org/apache/iceberg/orc/ORCSchemaUtil.java @@ -263,36 +263,36 @@ public static TypeDescription buildOrcProjection(Schema schema, private static TypeDescription buildOrcProjection(Integer fieldId, Type type, boolean isRequired, Map mapping) { final TypeDescription orcType; + final OrcField orcField = mapping.getOrDefault(fieldId, null); switch (type.typeId()) { case STRUCT: orcType = buildOrcProjectForStructType(fieldId, type, isRequired, mapping); break; case LIST: - Types.ListType list = (Types.ListType) type; - TypeDescription elementType = buildOrcProjection(list.elementId(), list.elementType(), - isRequired && list.isElementRequired(), mapping); - orcType = TypeDescription.createList(elementType); + orcType = buildOrcProjectionForListType((Types.ListType) type, isRequired, mapping, orcField); break; case MAP: - Types.MapType map = (Types.MapType) type; - TypeDescription keyType = buildOrcProjection(map.keyId(), map.keyType(), isRequired, mapping); - TypeDescription valueType = buildOrcProjection(map.valueId(), map.valueType(), - isRequired && map.isValueRequired(), mapping); - orcType = TypeDescription.createMap(keyType, valueType); + orcType = buildOrcProjectionForMapType((Types.MapType) type, isRequired, mapping, orcField); break; default: if (mapping.containsKey(fieldId)) { TypeDescription originalType = mapping.get(fieldId).type(); - Optional promotedType = getPromotedType(type, originalType); - - if (promotedType.isPresent()) { - orcType = promotedType.get(); - } else { - Preconditions.checkArgument(isSameType(originalType, type), - "Can not promote %s type to %s", - originalType.getCategory(), type.typeId().name()); + if (originalType != null && originalType.getCategory().equals(TypeDescription.Category.UNION)) { + Preconditions.checkState(originalType.getChildren().size() == 1, + "Expect single type union for orc schema."); orcType = originalType.clone(); + } else { + Optional promotedType = getPromotedType(type, originalType); + + if (promotedType.isPresent()) { + orcType = promotedType.get(); + } else { + Preconditions.checkArgument(isSameType(originalType, type), + "Can not promote %s type to %s", + originalType.getCategory(), type.typeId().name()); + orcType = originalType.clone(); + } } } else { if (isRequired) { @@ -307,19 +307,58 @@ private static TypeDescription buildOrcProjection(Integer fieldId, Type type, bo return orcType; } + private static TypeDescription buildOrcProjectionForMapType(Types.MapType type, boolean isRequired, + Map mapping, OrcField orcField) { + final TypeDescription orcType; + if (orcField != null && orcField.type.getCategory().equals(TypeDescription.Category.UNION)) { + Preconditions.checkState(orcField.type.getChildren().size() == 1, + "Expect single type union for orc schema."); + + orcType = TypeDescription.createUnion(); + Types.MapType map = type; + TypeDescription keyType = buildOrcProjection(map.keyId(), map.keyType(), isRequired, mapping); + TypeDescription valueType = buildOrcProjection(map.valueId(), map.valueType(), + isRequired && map.isValueRequired(), mapping); + orcType.addUnionChild(TypeDescription.createMap(keyType, valueType)); + } else { + Types.MapType map = type; + TypeDescription keyType = buildOrcProjection(map.keyId(), map.keyType(), isRequired, mapping); + TypeDescription valueType = buildOrcProjection(map.valueId(), map.valueType(), + isRequired && map.isValueRequired(), mapping); + orcType = TypeDescription.createMap(keyType, valueType); + } + return orcType; + } + + private static TypeDescription buildOrcProjectionForListType(Types.ListType type, boolean isRequired, + Map mapping, OrcField orcField) { + final TypeDescription orcType; + if (orcField != null && orcField.type.getCategory().equals(TypeDescription.Category.UNION)) { + Preconditions.checkState(orcField.type.getChildren().size() == 1, + "Expect single type union for orc schema."); + + orcType = TypeDescription.createUnion(); + Types.ListType list = type; + TypeDescription elementType = buildOrcProjection(list.elementId(), list.elementType(), + isRequired && list.isElementRequired(), mapping); + orcType.addUnionChild(TypeDescription.createList(elementType)); + } else { + Types.ListType list = type; + TypeDescription elementType = buildOrcProjection(list.elementId(), list.elementType(), + isRequired && list.isElementRequired(), mapping); + orcType = TypeDescription.createList(elementType); + } + return orcType; + } + private static TypeDescription buildOrcProjectForStructType(Integer fieldId, Type type, boolean isRequired, Map mapping) { TypeDescription orcType; OrcField orcField = mapping.getOrDefault(fieldId, null); - // this branch means the iceberg struct schema actually correspond to an underlying union + if (orcField != null && orcField.type.getCategory().equals(TypeDescription.Category.UNION)) { - orcType = TypeDescription.createUnion(); - List nestedFields = type.asStructType().fields(); - for (Types.NestedField nestedField : nestedFields.subList(1, nestedFields.size())) { - TypeDescription childType = buildOrcProjection(nestedField.fieldId(), nestedField.type(), - isRequired && nestedField.isRequired(), mapping); - orcType.addUnionChild(childType); - } + // this branch means the iceberg struct schema actually correspond to an underlying union + orcType = getOrcSchemaForUnionType(type, isRequired, mapping, orcField); } else { orcType = TypeDescription.createStruct(); for (Types.NestedField nestedField : type.asStructType().fields()) { @@ -340,6 +379,38 @@ private static TypeDescription buildOrcProjectForStructType(Integer fieldId, Typ return orcType; } + private static TypeDescription getOrcSchemaForUnionType(Type type, boolean isRequired, Map mapping, + OrcField orcField) { + TypeDescription orcType; + if (orcField.type.getChildren().size() == 1) { // single type union + orcType = TypeDescription.createUnion(); + + TypeDescription childOrcStructType = TypeDescription.createStruct(); + for (Types.NestedField nestedField : type.asStructType().fields()) { + if (mapping.get(nestedField.fieldId()) == null && nestedField.hasDefaultValue()) { + continue; + } + String name = Optional.ofNullable(mapping.get(nestedField.fieldId())) + .map(OrcField::name) + .orElseGet(() -> nestedField.name()); + TypeDescription childType = buildOrcProjection(nestedField.fieldId(), nestedField.type(), + isRequired && nestedField.isRequired(), mapping); + childOrcStructType.addField(name, childType); + } + + orcType.addUnionChild(childOrcStructType); + } else { // complex union + orcType = TypeDescription.createUnion(); + List nestedFields = type.asStructType().fields(); + for (Types.NestedField nestedField : nestedFields.subList(1, nestedFields.size())) { + TypeDescription childType = buildOrcProjection(nestedField.fieldId(), nestedField.type(), + isRequired && nestedField.isRequired(), mapping); + orcType.addUnionChild(childType); + } + } + return orcType; + } + private static Map icebergToOrcMapping(String name, TypeDescription orcType) { Map icebergToOrc = Maps.newHashMap(); switch (orcType.getCategory()) { diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java index 87679d140b..c23c86d8ad 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaVisitor.java @@ -49,14 +49,20 @@ public static T visit(TypeDescription schema, OrcSchemaVisitor visitor) { case UNION: List types = schema.getChildren(); List options = Lists.newArrayListWithExpectedSize(types.size()); - for (int i = 0; i < types.size(); i++) { - visitor.beforeUnionOption(types.get(i), i); - try { - options.add(visit(types.get(i), visitor)); - } finally { - visitor.afterUnionOption(types.get(i), i); + + if (types.size() == 1) { + options.add(visit(types.get(0), visitor)); + } else { + for (int i = 0; i < types.size(); i++) { + visitor.beforeUnionOption(types.get(i), i); + try { + options.add(visit(types.get(i), visitor)); + } finally { + visitor.afterUnionOption(types.get(i), i); + } } } + return visitor.union(schema, options); case LIST: diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaWithTypeVisitor.java b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaWithTypeVisitor.java index c30eea2483..8ce309e6cf 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaWithTypeVisitor.java +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcSchemaWithTypeVisitor.java @@ -75,8 +75,12 @@ protected T visitUnion(Type type, TypeDescription union, OrcSchemaWithTypeVisito List types = union.getChildren(); List options = Lists.newArrayListWithCapacity(types.size()); - for (int i = 0; i < types.size(); i += 1) { - options.add(visit(type.asStructType().fields().get(i + 1).type(), types.get(i), visitor)); + if (types.size() == 1) { // single type union + options.add(visit(type, types.get(0), visitor)); + } else { // complex union + for (int i = 0; i < types.size(); i += 1) { + options.add(visit(type.asStructType().fields().get(i + 1).type(), types.get(i), visitor)); + } } return visitor.union(type, union, options); diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java index f3f1c8df05..eaede8262d 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java @@ -164,7 +164,7 @@ protected void set(InternalRow struct, int pos, Object value) { } } - static class UnionReader implements OrcValueReader { + static class UnionReader implements OrcValueReader { private final OrcValueReader[] readers; private UnionReader(List> readers) { @@ -175,20 +175,23 @@ private UnionReader(List> readers) { } @Override - public InternalRow nonNullRead(ColumnVector vector, int row) { - InternalRow struct = new GenericInternalRow(readers.length + 1); + public Object nonNullRead(ColumnVector vector, int row) { UnionColumnVector unionColumnVector = (UnionColumnVector) vector; - int fieldIndex = unionColumnVector.tags[row]; Object value = this.readers[fieldIndex].read(unionColumnVector.fields[fieldIndex], row); - for (int i = 0; i < readers.length; i += 1) { - struct.setNullAt(i + 1); + if (readers.length == 1) { + return value; + } else { + InternalRow struct = new GenericInternalRow(readers.length + 1); + for (int i = 0; i < readers.length; i += 1) { + struct.setNullAt(i + 1); + } + struct.update(0, fieldIndex); + struct.update(fieldIndex + 1, value); + + return struct; } - struct.update(0, fieldIndex); - struct.update(fieldIndex + 1, value); - - return struct; } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java index 3cddbfe479..edcc006dd1 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcUnions.java @@ -238,6 +238,223 @@ public void testDeeplyNestedUnion() throws IOException { } } + @Test + public void testSingleTypeUnion() throws IOException { + TypeDescription orcSchema = + TypeDescription.fromString("struct>"); + + Schema expectedSchema = new Schema(Types.NestedField.optional(0, "unionCol", Types.StringType.get())); + + final InternalRow expectedFirstRow = new GenericInternalRow(1); + expectedFirstRow.update(0, UTF8String.fromString("foo-0")); + + final InternalRow expectedSecondRow = new GenericInternalRow(1); + expectedSecondRow.update(0, UTF8String.fromString("foo-1")); + + Configuration conf = new Configuration(); + + File orcFile = temp.newFile(); + Path orcFilePath = new Path(orcFile.getPath()); + + Writer writer = OrcFile.createWriter(orcFilePath, + OrcFile.writerOptions(conf) + .setSchema(orcSchema).overwrite(true)); + + VectorizedRowBatch batch = orcSchema.createRowBatch(); + BytesColumnVector bytesColumnVector = new BytesColumnVector(NUM_OF_ROWS); + UnionColumnVector complexUnion = new UnionColumnVector(NUM_OF_ROWS, bytesColumnVector); + + complexUnion.init(); + + for (int i = 0; i < NUM_OF_ROWS; i += 1) { + complexUnion.tags[i] = 0; + String stringValue = "foo-" + i; + bytesColumnVector.setVal(i, stringValue.getBytes(StandardCharsets.UTF_8)); + } + + batch.size = NUM_OF_ROWS; + batch.cols[0] = complexUnion; + + writer.addRowBatch(batch); + batch.reset(); + writer.close(); + + // Test non-vectorized reader + List actualRows = Lists.newArrayList(); + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(expectedSchema, readOrcSchema)) + .build()) { + reader.forEach(actualRows::add); + + Assert.assertEquals(actualRows.size(), NUM_OF_ROWS); + assertEquals(expectedSchema, expectedFirstRow, actualRows.get(0)); + assertEquals(expectedSchema, expectedSecondRow, actualRows.get(1)); + } + + // Test vectorized reader + /* + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(expectedSchema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRowsIt = batchesToRows(reader.iterator()); + + assertEquals(expectedSchema, expectedFirstRow, actualRowsIt.next()); + assertEquals(expectedSchema, expectedSecondRow, actualRowsIt.next()); + } + */ + } + + @Test + public void testSingleTypeUnionOfStruct() throws IOException { + TypeDescription orcSchema = + TypeDescription.fromString("struct>>"); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(0, "unionCol", Types.StructType.of( + Types.NestedField.optional(1, "c", Types.StringType.get()) + ))); + + final InternalRow expectedFirstRow = new GenericInternalRow(1); + final InternalRow innerExpectedFirstRow = new GenericInternalRow(1); + innerExpectedFirstRow.update(0, UTF8String.fromString("foo-0")); + expectedFirstRow.update(0, innerExpectedFirstRow); + + final InternalRow expectedSecondRow = new GenericInternalRow(1); + final InternalRow innerExpectedSecondRow = new GenericInternalRow(1); + innerExpectedSecondRow.update(0, UTF8String.fromString("foo-1")); + expectedSecondRow.update(0, innerExpectedSecondRow); + + Configuration conf = new Configuration(); + + File orcFile = temp.newFile(); + Path orcFilePath = new Path(orcFile.getPath()); + + Writer writer = OrcFile.createWriter(orcFilePath, + OrcFile.writerOptions(conf) + .setSchema(orcSchema).overwrite(true)); + + VectorizedRowBatch batch = orcSchema.createRowBatch(); + UnionColumnVector complexUnion = (UnionColumnVector) batch.cols[0]; + StructColumnVector structColumnVector = (StructColumnVector) complexUnion.fields[0]; + BytesColumnVector bytesColumnVector = (BytesColumnVector) structColumnVector.fields[0]; + + for (int i = 0; i < NUM_OF_ROWS; i += 1) { + complexUnion.tags[i] = 0; + String stringValue = "foo-" + i; + bytesColumnVector.setVal(i, stringValue.getBytes(StandardCharsets.UTF_8)); + } + + batch.size = NUM_OF_ROWS; + writer.addRowBatch(batch); + batch.reset(); + writer.close(); + + // Test non-vectorized reader + List actualRows = Lists.newArrayList(); + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(expectedSchema, readOrcSchema)) + .build()) { + reader.forEach(actualRows::add); + + Assert.assertEquals(actualRows.size(), NUM_OF_ROWS); + assertEquals(expectedSchema, expectedFirstRow, actualRows.get(0)); + assertEquals(expectedSchema, expectedSecondRow, actualRows.get(1)); + } + + // Test vectorized reader + /* + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(expectedSchema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRowsIt = batchesToRows(reader.iterator()); + + assertEquals(expectedSchema, expectedFirstRow, actualRowsIt.next()); + assertEquals(expectedSchema, expectedSecondRow, actualRowsIt.next()); + } + */ + } + + @Test + public void testDeepNestedSingleTypeUnion() throws IOException { + TypeDescription orcSchema = + TypeDescription.fromString("struct>>>"); + + Schema expectedSchema = new Schema( + Types.NestedField.optional(0, "outerUnion", Types.StructType.of( + Types.NestedField.optional(1, "innerUnion", Types.StringType.get()) + ))); + + final InternalRow expectedFirstRow = new GenericInternalRow(1); + final InternalRow innerExpectedFirstRow = new GenericInternalRow(1); + innerExpectedFirstRow.update(0, UTF8String.fromString("foo-0")); + expectedFirstRow.update(0, innerExpectedFirstRow); + + final InternalRow expectedSecondRow = new GenericInternalRow(1); + final InternalRow innerExpectedSecondRow = new GenericInternalRow(1); + innerExpectedSecondRow.update(0, UTF8String.fromString("foo-1")); + expectedSecondRow.update(0, innerExpectedSecondRow); + + Configuration conf = new Configuration(); + + File orcFile = temp.newFile(); + Path orcFilePath = new Path(orcFile.getPath()); + + Writer writer = OrcFile.createWriter(orcFilePath, + OrcFile.writerOptions(conf) + .setSchema(orcSchema).overwrite(true)); + + VectorizedRowBatch batch = orcSchema.createRowBatch(); + UnionColumnVector outerUnion = (UnionColumnVector) batch.cols[0]; + StructColumnVector structColumnVector = (StructColumnVector) outerUnion.fields[0]; + UnionColumnVector innerUnion = (UnionColumnVector) structColumnVector.fields[0]; + BytesColumnVector bytesColumnVector = (BytesColumnVector) innerUnion.fields[0]; + + for (int i = 0; i < NUM_OF_ROWS; i += 1) { + outerUnion.tags[i] = 0; + innerUnion.tags[i] = 0; + String stringValue = "foo-" + i; + bytesColumnVector.setVal(i, stringValue.getBytes(StandardCharsets.UTF_8)); + } + + batch.size = NUM_OF_ROWS; + writer.addRowBatch(batch); + batch.reset(); + writer.close(); + + // Test non-vectorized reader + List actualRows = Lists.newArrayList(); + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createReaderFunc(readOrcSchema -> new SparkOrcReader(expectedSchema, readOrcSchema)) + .build()) { + reader.forEach(actualRows::add); + + Assert.assertEquals(actualRows.size(), NUM_OF_ROWS); + assertEquals(expectedSchema, expectedFirstRow, actualRows.get(0)); + assertEquals(expectedSchema, expectedSecondRow, actualRows.get(1)); + } + + // Test vectorized reader + /* + try (CloseableIterable reader = ORC.read(Files.localInput(orcFile)) + .project(expectedSchema) + .createBatchedReaderFunc(readOrcSchema -> + VectorizedSparkOrcReaders.buildReader(expectedSchema, readOrcSchema, ImmutableMap.of())) + .build()) { + final Iterator actualRowsIt = batchesToRows(reader.iterator()); + + assertEquals(expectedSchema, expectedFirstRow, actualRowsIt.next()); + assertEquals(expectedSchema, expectedSecondRow, actualRowsIt.next()); + } + */ + } + private Iterator batchesToRows(Iterator batches) { return Iterators.concat(Iterators.transform(batches, ColumnarBatch::rowIterator)); }