From 5cf92487d5e64764b75284867471acb4e3cddb4d Mon Sep 17 00:00:00 2001 From: Constantin Muraru Date: Sat, 29 Apr 2017 22:33:43 +0300 Subject: [PATCH 1/3] PARQUET-968 Add Hive support in ProtoParquet Fix bug with map writer Implement review Implement review PARQUET-968 Implement feedback Update the proto to parquet schema converter for MAP fields so that it follows the scec: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists This came as feedback from the Amazon Athena team. --- .../parquet/proto/ProtoMessageConverter.java | 131 ++++++++++++++++- .../parquet/proto/ProtoSchemaConverter.java | 139 ++++++++++++++---- .../parquet/proto/ProtoWriteSupport.java | 95 +++++++++++- .../proto/ProtoSchemaConverterTest.java | 38 +++-- .../parquet/proto/ProtoWriteSupportTest.java | 48 ++++++ 5 files changed, 401 insertions(+), 50 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index b5649a05b6..953994f1c1 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -24,12 +24,14 @@ import com.twitter.elephantbird.util.Protobufs; import org.apache.parquet.column.Dictionary; import org.apache.parquet.io.InvalidRecordException; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.Type; import java.util.HashMap; @@ -129,10 +131,14 @@ public void add(Object value) { }; } - return newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType); + OriginalType originalType = parquetType.getOriginalType() == null ? OriginalType.UTF8 : parquetType.getOriginalType(); + switch (originalType) { + case LIST: return new ListConverter(parentBuilder, fieldDescriptor, parquetType); + case MAP: return new MapConverter(parentBuilder, fieldDescriptor, parquetType); + default: return newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType); + } } - private Converter newScalarConverter(ParentValueContainer pvc, Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { JavaType javaType = fieldDescriptor.getJavaType(); @@ -345,4 +351,125 @@ public void addBinary(Binary binary) { } } + + /** + * This class unwraps the additional LIST wrapper and makes it possible to read the underlying data and then convert + * it to protobuf. + *

+ * Consider the following protobuf schema: + * message SimpleList { + * repeated int64 first_array = 1; + * } + *

+ * A LIST wrapper is created in parquet for the above mentioned protobuf schema: + * message SimpleList { + * required group first_array (LIST) = 1 { + * repeated int32 element; + * } + * } + *

+ * The LIST wrappers are used by 3rd party tools, such as Hive, to read parquet arrays. The wrapper contains + * one only one field: either a primitive field (like in the example above, where we have an array of ints) or + * another group (array of messages). + */ + final class ListConverter extends GroupConverter { + private final Converter converter; + private final boolean listOfMessage; + + public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { + OriginalType originalType = parquetType.getOriginalType(); + if (originalType != OriginalType.LIST) { + throw new ParquetDecodingException("Expected LIST wrapper. Found: " + originalType + " instead."); + } + + listOfMessage = fieldDescriptor.getJavaType() == JavaType.MESSAGE; + + Type parquetSchema; + if (parquetType.asGroupType().containsField("list")) { + parquetSchema = parquetType.asGroupType().getType("list"); + if (parquetSchema.asGroupType().containsField("element")) { + parquetSchema.asGroupType().getType("element"); + } + } else { + throw new ParquetDecodingException("Expected list but got: " + parquetType); + } + + converter = newMessageConverter(parentBuilder, fieldDescriptor, parquetSchema); + } + + @Override + public Converter getConverter(int fieldIndex) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); + } + + if (listOfMessage) { + return converter; + } + + return new GroupConverter() { + @Override + public Converter getConverter(int fieldIndex) { + return converter; + } + + @Override + public void start() { + + } + + @Override + public void end() { + + } + }; + } + + @Override + public void start() { + + } + + @Override + public void end() { + + } + } + + + final class MapConverter extends GroupConverter { + private final Converter converter; + + public MapConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { + OriginalType originalType = parquetType.getOriginalType(); + if (originalType != OriginalType.MAP) { + throw new ParquetDecodingException("Expected MAP wrapper. Found: " + originalType + " instead."); + } + + Type parquetSchema; + if (parquetType.asGroupType().containsField("key_value")){ + parquetSchema = parquetType.asGroupType().getType("key_value"); + } else { + throw new ParquetDecodingException("Expected map but got: " + parquetType); + } + + converter = newMessageConverter(parentBuilder, fieldDescriptor, parquetSchema); + } + + @Override + public Converter getConverter(int fieldIndex) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the MAP wrapper"); + } + return converter; + } + + @Override + public void start() { + } + + @Override + public void end() { + } + } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index 2c4a1caeec..f3dd11db38 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -18,30 +18,26 @@ */ package org.apache.parquet.proto; -import static org.apache.parquet.schema.OriginalType.ENUM; -import static org.apache.parquet.schema.OriginalType.UTF8; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; -import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; - -import java.util.List; - +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; +import com.google.protobuf.Message; +import com.twitter.elephantbird.util.Protobufs; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; import org.apache.parquet.schema.Types.Builder; import org.apache.parquet.schema.Types.GroupBuilder; - -import com.google.protobuf.Descriptors; -import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; -import com.google.protobuf.Message; -import com.twitter.elephantbird.util.Protobufs; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.List; + +import static org.apache.parquet.schema.OriginalType.ENUM; +import static org.apache.parquet.schema.OriginalType.UTF8; +import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.*; + /** *

* Converts a Protocol Buffer Descriptor into a Parquet schema. @@ -83,26 +79,105 @@ private Type.Repetition getRepetition(Descriptors.FieldDescriptor descriptor) { } } - private Builder>, GroupBuilder> addField(Descriptors.FieldDescriptor descriptor, GroupBuilder builder) { - Type.Repetition repetition = getRepetition(descriptor); - JavaType javaType = descriptor.getJavaType(); + private Builder>, GroupBuilder> addField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + if (descriptor.getJavaType() == JavaType.MESSAGE) { + return addMessageField(descriptor, builder); + } + + ParquetType parquetType = getParquetType(descriptor); + if (descriptor.isRepeated()) { + return addRepeatedPrimitive(descriptor, parquetType.primitiveType, parquetType.originalType, builder); + } + + return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.originalType); + } + + private Builder>, GroupBuilder> addRepeatedPrimitive(Descriptors.FieldDescriptor descriptor, + PrimitiveTypeName primitiveType, + OriginalType originalType, + final GroupBuilder builder) { + return builder + .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) + .group(Type.Repetition.REPEATED) + .primitive(primitiveType, Type.Repetition.REQUIRED).as(originalType) + .named("element") + .named("list"); + } + + private GroupBuilder> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder builder) { + GroupBuilder>> result = + builder + .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) + .group(Type.Repetition.REPEATED); + + convertFields(result, descriptor.getMessageType().getFields()); + + return result.named("list"); + } + + private GroupBuilder> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + if (descriptor.isMapField()) { + return addMapField(descriptor, builder); + } else if (descriptor.isRepeated()) { + return addRepeatedMessage(descriptor, builder); + } + + // Plain message + GroupBuilder> group = builder.group(getRepetition(descriptor)); + convertFields(group, descriptor.getMessageType().getFields()); + return group; + } + + private GroupBuilder> addMapField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + List fields = descriptor.getMessageType().getFields(); + if (fields.size() != 2) { + throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); + } + + ParquetType mapKeyParquetType = getParquetType(fields.get(0)); + + GroupBuilder>> group = builder + .group(Type.Repetition.REQUIRED).as(OriginalType.MAP) + .group(Type.Repetition.REPEATED) // key_value wrapper + .primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key"); + + return addField(fields.get(1), group).named("value") + .named("key_value"); + } + + private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) { + + JavaType javaType = fieldDescriptor.getJavaType(); switch (javaType) { - case BOOLEAN: return builder.primitive(BOOLEAN, repetition); - case INT: return builder.primitive(INT32, repetition); - case LONG: return builder.primitive(INT64, repetition); - case FLOAT: return builder.primitive(FLOAT, repetition); - case DOUBLE: return builder.primitive(DOUBLE, repetition); - case BYTE_STRING: return builder.primitive(BINARY, repetition); - case STRING: return builder.primitive(BINARY, repetition).as(UTF8); - case MESSAGE: { - GroupBuilder> group = builder.group(repetition); - convertFields(group, descriptor.getMessageType().getFields()); - return group; - } - case ENUM: return builder.primitive(BINARY, repetition).as(ENUM); + case INT: return ParquetType.of(INT32); + case LONG: return ParquetType.of(INT64); + case DOUBLE: return ParquetType.of(DOUBLE); + case BOOLEAN: return ParquetType.of(BOOLEAN); + case FLOAT: return ParquetType.of(FLOAT); + case STRING: return ParquetType.of(BINARY, UTF8); + case ENUM: return ParquetType.of(BINARY, ENUM); + case BYTE_STRING: return ParquetType.of(BINARY); default: throw new UnsupportedOperationException("Cannot convert Protocol Buffer: unknown type " + javaType); } } + private static class ParquetType { + PrimitiveTypeName primitiveType; + OriginalType originalType; + + private ParquetType(PrimitiveTypeName primitiveType, OriginalType originalType) { + this.primitiveType = primitiveType; + this.originalType = originalType; + } + + public static ParquetType of(PrimitiveTypeName primitiveType, OriginalType originalType) { + return new ParquetType(primitiveType, originalType); + } + + public static ParquetType of(PrimitiveTypeName primitiveType) { + return of(primitiveType, null); + } + } + } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index c0ed351046..8e2b4aeb44 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -21,6 +21,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; +import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.TextFormat; @@ -34,11 +35,13 @@ import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.lang.reflect.Array; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -162,7 +165,7 @@ class MessageWriter extends FieldWriter { Type type = schema.getType(name); FieldWriter writer = createWriter(fieldDescriptor, type); - if(fieldDescriptor.isRepeated()) { + if(fieldDescriptor.isRepeated() && !fieldDescriptor.isMapField()) { writer = new ArrayWriter(writer); } @@ -177,7 +180,7 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty switch (fieldDescriptor.getJavaType()) { case STRING: return new StringWriter() ; - case MESSAGE: return new MessageWriter(fieldDescriptor.getMessageType(), type.asGroupType()); + case MESSAGE: return createMessageWriter(fieldDescriptor, type); case INT: return new IntWriter(); case LONG: return new LongWriter(); case FLOAT: return new FloatWriter(); @@ -190,6 +193,47 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty return unknownType(fieldDescriptor);//should not be executed, always throws exception. } + private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + if (fieldDescriptor.isMapField()) { + return createMapWriter(fieldDescriptor, type); + } + + return new MessageWriter(fieldDescriptor.getMessageType(), getGroupType(type)); + } + + private GroupType getGroupType(Type type) { + if (type.getOriginalType() == OriginalType.LIST) { + return type.asGroupType().getType("list").asGroupType(); + } + + if (type.getOriginalType() == OriginalType.MAP) { + return type.asGroupType().getType("key_value").asGroupType().getType("value").asGroupType(); + } + + return type.asGroupType(); + } + + private MapWriter createMapWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + List fields = fieldDescriptor.getMessageType().getFields(); + if (fields.size() != 2) { + throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); + } + + // KeyFieldWriter + Descriptors.FieldDescriptor keyProtoField = fields.get(0); + FieldWriter keyWriter = createWriter(keyProtoField, type); + keyWriter.setFieldName(keyProtoField.getName()); + keyWriter.setIndex(0); + + // ValueFieldWriter + Descriptors.FieldDescriptor valueProtoField = fields.get(1); + FieldWriter valueWriter = createWriter(valueProtoField, type); + valueWriter.setFieldName(valueProtoField.getName()); + valueWriter.setIndex(1); + + return new MapWriter(keyWriter, valueWriter); + } + /** Writes top level message. It cannot call startGroup() */ void writeTopLevelMessage(Object value) { writeAllFields((MessageOrBuilder) value); @@ -198,9 +242,7 @@ void writeTopLevelMessage(Object value) { /** Writes message as part of repeated field. It cannot start field*/ @Override final void writeRawValue(Object value) { - recordConsumer.startGroup(); writeAllFields((MessageOrBuilder) value); - recordConsumer.endGroup(); } /** Used for writing nonrepeated (optional, required) fields*/ @@ -247,16 +289,32 @@ final void writeRawValue(Object value) { @Override final void writeField(Object value) { recordConsumer.startField(fieldName, index); + recordConsumer.startGroup(); List list = (List) value; + recordConsumer.startField("list", 0); // This is the wrapper group for the array field for (Object listEntry: list) { + recordConsumer.startGroup(); + if (isPrimitive(listEntry)) { + recordConsumer.startField("element", 0); + } fieldWriter.writeRawValue(listEntry); + if (isPrimitive(listEntry)) { + recordConsumer.endField("element", 0); + } + recordConsumer.endGroup(); } + recordConsumer.endField("list", 0); + recordConsumer.endGroup(); recordConsumer.endField(fieldName, index); } } + private boolean isPrimitive(Object listEntry) { + return !(listEntry instanceof Message); + } + /** validates mapping between protobuffer fields and parquet fields.*/ private void validatedMapping(Descriptors.Descriptor descriptor, GroupType parquetSchema) { List allFields = descriptor.getFields(); @@ -296,6 +354,35 @@ final void writeRawValue(Object value) { } } + class MapWriter extends FieldWriter { + + private final FieldWriter keyWriter; + private final FieldWriter valueWriter; + + public MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) { + super(); + this.keyWriter = keyWriter; + this.valueWriter = valueWriter; + } + + @Override + final void writeRawValue(Object value) { + recordConsumer.startGroup(); + + recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field + for(MapEntry entry : (Collection>) value) { + recordConsumer.startGroup(); + keyWriter.writeField(entry.getKey()); + valueWriter.writeField(entry.getValue()); + recordConsumer.endGroup(); + } + + recordConsumer.endField("key_value", 0); + + recordConsumer.endGroup(); + } + } + class FloatWriter extends FieldWriter { @Override final void writeRawValue(Object value) { diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 6f5ff53b69..70bc1f79d9 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -103,10 +103,12 @@ public void testProto3ConvertAllDatatypes() throws Exception { " optional binary optionalEnum (ENUM) = 18;" + " optional int32 someInt32 = 19;" + " optional binary someString (UTF8) = 20;" + - " repeated group optionalMap = 21 {\n" + - " optional int64 key = 1;\n" + - " optional group value = 2 {\n" + - " optional int32 someId = 3;\n" + + " required group optionalMap (MAP) = 21 {\n" + + " repeated group key_value {\n" + + " required int64 key;\n" + + " optional group value {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }\n" + "}"; @@ -120,16 +122,22 @@ public void testConvertRepetition() throws Exception { "message TestProtobuf.SchemaConverterRepetition {\n" + " optional int32 optionalPrimitive = 1;\n" + " required int32 requiredPrimitive = 2;\n" + - " repeated int32 repeatedPrimitive = 3;\n" + + " required group repeatedPrimitive (LIST) = 3 {\n" + + " repeated group list {\n" + + " required int32 element;\n" + + " }\n" + + " }\n" + " optional group optionalMessage = 7 {\n" + " optional int32 someId = 3;\n" + " }\n" + - " required group requiredMessage = 8 {" + + " required group requiredMessage = 8 {\n" + " optional int32 someId= 3;\n" + " }\n" + - " repeated group repeatedMessage = 9 {" + - " optional int32 someId = 3;\n" + - " }\n" + + " required group repeatedMessage (LIST) = 9 {\n" + + " repeated group list {\n" + + " optional int32 someId = 3;\n" + + " }\n" + + " }" + "}"; testConversion(TestProtobuf.SchemaConverterRepetition.class, expectedSchema); @@ -140,12 +148,18 @@ public void testProto3ConvertRepetition() throws Exception { String expectedSchema = "message TestProto3.SchemaConverterRepetition {\n" + " optional int32 optionalPrimitive = 1;\n" + - " repeated int32 repeatedPrimitive = 3;\n" + + " required group repeatedPrimitive (LIST) = 3 {\n" + + " repeated group list {\n" + + " required int32 element;\n" + + " }\n" + + " }\n" + " optional group optionalMessage = 7 {\n" + " optional int32 someId = 3;\n" + " }\n" + - " repeated group repeatedMessage = 9 {" + - " optional int32 someId = 3;\n" + + " required group repeatedMessage (LIST) = 9 {\n" + + " repeated group list {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + "}"; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index b937618c3b..e00facfc06 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -94,8 +94,23 @@ public void testRepeatedIntMessage() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("repeatedInt", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).addInteger(1323); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).addInteger(54469); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("repeatedInt", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); @@ -116,8 +131,23 @@ public void testProto3RepeatedIntMessage() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("repeatedInt", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).addInteger(1323); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).addInteger(54469); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("repeatedInt", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); @@ -138,6 +168,8 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); @@ -145,6 +177,8 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); @@ -165,6 +199,8 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); @@ -172,6 +208,8 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); @@ -192,6 +230,9 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + //first inner message inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); @@ -206,6 +247,8 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); @@ -226,6 +269,9 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("list", 0); + //first inner message inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); @@ -240,6 +286,8 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); inOrder.verify(readConsumerMock).endMessage(); Mockito.verifyNoMoreInteractions(readConsumerMock); From a8bd704138e9a3d99d96752766b80b5fb576802e Mon Sep 17 00:00:00 2001 From: Constantin Muraru Date: Wed, 6 Sep 2017 00:06:11 +0300 Subject: [PATCH 2/3] Pick up commit from @andredasilvapinto https://github.com/andredasilvapinto/parquet-mr/commit/dfa9701a4d843bb7cd1d429d86d17811b735f33c --- .../parquet/proto/ProtoMessageConverter.java | 6 +- .../parquet/proto/ProtoSchemaConverter.java | 32 ++++---- .../parquet/proto/ProtoWriteSupport.java | 81 ++++++++++--------- .../proto/ProtoSchemaConverterTest.java | 10 ++- .../parquet/proto/ProtoWriteSupportTest.java | 28 +++++++ 5 files changed, 97 insertions(+), 60 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 953994f1c1..bb9930b6fe 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -388,7 +388,7 @@ public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor if (parquetType.asGroupType().containsField("list")) { parquetSchema = parquetType.asGroupType().getType("list"); if (parquetSchema.asGroupType().containsField("element")) { - parquetSchema.asGroupType().getType("element"); + parquetSchema = parquetSchema.asGroupType().getType("element"); } } else { throw new ParquetDecodingException("Expected list but got: " + parquetType); @@ -403,10 +403,6 @@ public Converter getConverter(int fieldIndex) { throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); } - if (listOfMessage) { - return converter; - } - return new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index f3dd11db38..eae54ebefa 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -19,6 +19,7 @@ package org.apache.parquet.proto; import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; import com.google.protobuf.Message; import com.twitter.elephantbird.util.Protobufs; @@ -59,8 +60,8 @@ public MessageType convert(Class protobufClass) { } /* Iterates over list of fields. **/ - private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { - for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) { + private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { + for (FieldDescriptor fieldDescriptor : fieldDescriptors) { groupBuilder = addField(fieldDescriptor, groupBuilder) .id(fieldDescriptor.getNumber()) @@ -69,7 +70,7 @@ private GroupBuilder convertFields(GroupBuilder groupBuilder, List Builder>, GroupBuilder> addField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + private Builder>, GroupBuilder> addField(FieldDescriptor descriptor, final GroupBuilder builder) { if (descriptor.getJavaType() == JavaType.MESSAGE) { return addMessageField(descriptor, builder); } @@ -92,7 +93,7 @@ private Builder>, GroupBuilder> addF return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.originalType); } - private Builder>, GroupBuilder> addRepeatedPrimitive(Descriptors.FieldDescriptor descriptor, + private Builder>, GroupBuilder> addRepeatedPrimitive(FieldDescriptor descriptor, PrimitiveTypeName primitiveType, OriginalType originalType, final GroupBuilder builder) { @@ -104,18 +105,19 @@ private Builder>, GroupBuilder> addR .named("list"); } - private GroupBuilder> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder builder) { - GroupBuilder>> result = + private GroupBuilder> addRepeatedMessage(FieldDescriptor descriptor, GroupBuilder builder) { + GroupBuilder>>> result = builder .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) - .group(Type.Repetition.REPEATED); + .group(Type.Repetition.REPEATED) + .group(Type.Repetition.OPTIONAL); convertFields(result, descriptor.getMessageType().getFields()); - return result.named("list"); + return result.named("element").named("list"); } - private GroupBuilder> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + private GroupBuilder> addMessageField(FieldDescriptor descriptor, final GroupBuilder builder) { if (descriptor.isMapField()) { return addMapField(descriptor, builder); } else if (descriptor.isRepeated()) { @@ -128,8 +130,8 @@ private GroupBuilder> addMessageField(Descriptors.FieldDescr return group; } - private GroupBuilder> addMapField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { - List fields = descriptor.getMessageType().getFields(); + private GroupBuilder> addMapField(FieldDescriptor descriptor, final GroupBuilder builder) { + List fields = descriptor.getMessageType().getFields(); if (fields.size() != 2) { throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); } @@ -137,7 +139,7 @@ private GroupBuilder> addMapField(Descriptors.FieldDescripto ParquetType mapKeyParquetType = getParquetType(fields.get(0)); GroupBuilder>> group = builder - .group(Type.Repetition.REQUIRED).as(OriginalType.MAP) + .group(Type.Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3 .group(Type.Repetition.REPEATED) // key_value wrapper .primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key"); @@ -145,7 +147,7 @@ private GroupBuilder> addMapField(Descriptors.FieldDescripto .named("key_value"); } - private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) { + private ParquetType getParquetType(FieldDescriptor fieldDescriptor) { JavaType javaType = fieldDescriptor.getJavaType(); switch (javaType) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 8e2b4aeb44..bb75e71748 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -18,13 +18,9 @@ */ package org.apache.parquet.proto; -import com.google.protobuf.ByteString; -import com.google.protobuf.DescriptorProtos; -import com.google.protobuf.Descriptors; -import com.google.protobuf.MapEntry; -import com.google.protobuf.Message; -import com.google.protobuf.MessageOrBuilder; -import com.google.protobuf.TextFormat; +import com.google.protobuf.*; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.hadoop.BadConfigurationException; @@ -32,10 +28,7 @@ import org.apache.parquet.io.InvalidRecordException; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.RecordConsumer; -import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.IncompatibleSchemaModificationException; -import org.apache.parquet.schema.MessageType; -import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.*; import org.apache.parquet.schema.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,7 +106,7 @@ public WriteContext init(Configuration configuration) { } MessageType rootSchema = new ProtoSchemaConverter().convert(protoMessage); - Descriptors.Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage); + Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage); validatedMapping(messageDescriptor, rootSchema); this.messageWriter = new MessageWriter(messageDescriptor, rootSchema); @@ -156,11 +149,11 @@ class MessageWriter extends FieldWriter { final FieldWriter[] fieldWriters; @SuppressWarnings("unchecked") - MessageWriter(Descriptors.Descriptor descriptor, GroupType schema) { - List fields = descriptor.getFields(); + MessageWriter(Descriptor descriptor, GroupType schema) { + List fields = descriptor.getFields(); fieldWriters = (FieldWriter[]) Array.newInstance(FieldWriter.class, fields.size()); - for (Descriptors.FieldDescriptor fieldDescriptor: fields) { + for (FieldDescriptor fieldDescriptor: fields) { String name = fieldDescriptor.getName(); Type type = schema.getType(name); FieldWriter writer = createWriter(fieldDescriptor, type); @@ -176,7 +169,7 @@ class MessageWriter extends FieldWriter { } } - private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) { switch (fieldDescriptor.getJavaType()) { case STRING: return new StringWriter() ; @@ -193,7 +186,7 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty return unknownType(fieldDescriptor);//should not be executed, always throws exception. } - private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type type) { if (fieldDescriptor.isMapField()) { return createMapWriter(fieldDescriptor, type); } @@ -203,7 +196,7 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip private GroupType getGroupType(Type type) { if (type.getOriginalType() == OriginalType.LIST) { - return type.asGroupType().getType("list").asGroupType(); + return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType(); } if (type.getOriginalType() == OriginalType.MAP) { @@ -213,20 +206,20 @@ private GroupType getGroupType(Type type) { return type.asGroupType(); } - private MapWriter createMapWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { - List fields = fieldDescriptor.getMessageType().getFields(); + private MapWriter createMapWriter(FieldDescriptor fieldDescriptor, Type type) { + List fields = fieldDescriptor.getMessageType().getFields(); if (fields.size() != 2) { throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); } // KeyFieldWriter - Descriptors.FieldDescriptor keyProtoField = fields.get(0); + FieldDescriptor keyProtoField = fields.get(0); FieldWriter keyWriter = createWriter(keyProtoField, type); keyWriter.setFieldName(keyProtoField.getName()); keyWriter.setIndex(0); // ValueFieldWriter - Descriptors.FieldDescriptor valueProtoField = fields.get(1); + FieldDescriptor valueProtoField = fields.get(1); FieldWriter valueWriter = createWriter(valueProtoField, type); valueWriter.setFieldName(valueProtoField.getName()); valueWriter.setIndex(1); @@ -257,10 +250,10 @@ final void writeField(Object value) { private void writeAllFields(MessageOrBuilder pb) { //returns changed fields with values. Map is ordered by id. - Map changedPbFields = pb.getAllFields(); + Map changedPbFields = pb.getAllFields(); - for (Map.Entry entry : changedPbFields.entrySet()) { - Descriptors.FieldDescriptor fieldDescriptor = entry.getKey(); + for (Map.Entry entry : changedPbFields.entrySet()) { + FieldDescriptor fieldDescriptor = entry.getKey(); if(fieldDescriptor.isExtension()) { // Field index of an extension field might overlap with a base field. @@ -295,13 +288,21 @@ final void writeField(Object value) { recordConsumer.startField("list", 0); // This is the wrapper group for the array field for (Object listEntry: list) { recordConsumer.startGroup(); - if (isPrimitive(listEntry)) { - recordConsumer.startField("element", 0); + + recordConsumer.startField("element", 0); // This is the mandatory inner field + + if (!isPrimitive(listEntry)) { + recordConsumer.startGroup(); } + fieldWriter.writeRawValue(listEntry); - if (isPrimitive(listEntry)) { - recordConsumer.endField("element", 0); + + if (!isPrimitive(listEntry)) { + recordConsumer.endGroup(); } + + recordConsumer.endField("element", 0); + recordConsumer.endGroup(); } recordConsumer.endField("list", 0); @@ -316,10 +317,10 @@ private boolean isPrimitive(Object listEntry) { } /** validates mapping between protobuffer fields and parquet fields.*/ - private void validatedMapping(Descriptors.Descriptor descriptor, GroupType parquetSchema) { - List allFields = descriptor.getFields(); + private void validatedMapping(Descriptor descriptor, GroupType parquetSchema) { + List allFields = descriptor.getFields(); - for (Descriptors.FieldDescriptor fieldDescriptor: allFields) { + for (FieldDescriptor fieldDescriptor: allFields) { String fieldName = fieldDescriptor.getName(); int fieldIndex = fieldDescriptor.getIndex(); int parquetIndex = parquetSchema.getFieldIndex(fieldName); @@ -370,10 +371,16 @@ final void writeRawValue(Object value) { recordConsumer.startGroup(); recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field - for(MapEntry entry : (Collection>) value) { + for (Message msg : (Collection) value) { recordConsumer.startGroup(); - keyWriter.writeField(entry.getKey()); - valueWriter.writeField(entry.getValue()); + + final Descriptor descriptorForType = msg.getDescriptorForType(); + final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key"); + final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value"); + + keyWriter.writeField(msg.getField(keyDesc)); + valueWriter.writeField(msg.getField(valueDesc)); + recordConsumer.endGroup(); } @@ -421,7 +428,7 @@ final void writeRawValue(Object value) { } } - private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) { + private FieldWriter unknownType(FieldDescriptor fieldDescriptor) { String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor + "\" and type \"" + fieldDescriptor.getJavaType() + "\"."; throw new InvalidRecordException(exceptionMsg); @@ -429,7 +436,7 @@ private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) { /** Returns message descriptor as JSON String*/ private String serializeDescriptor(Class protoClass) { - Descriptors.Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass); + Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass); DescriptorProtos.DescriptorProto asProto = descriptor.toProto(); return TextFormat.printToString(asProto); } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 70bc1f79d9..d7ec169ce7 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -103,7 +103,7 @@ public void testProto3ConvertAllDatatypes() throws Exception { " optional binary optionalEnum (ENUM) = 18;" + " optional int32 someInt32 = 19;" + " optional binary someString (UTF8) = 20;" + - " required group optionalMap (MAP) = 21 {\n" + + " optional group optionalMap (MAP) = 21 {\n" + " repeated group key_value {\n" + " required int64 key;\n" + " optional group value {\n" + @@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }" + "}"; @@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }\n" + "}"; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index e00facfc06..de27ebf3f4 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -169,6 +169,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); @@ -177,6 +180,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); @@ -201,12 +207,18 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); inOrder.verify(readConsumerMock).startGroup(); + + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -235,17 +247,25 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -274,17 +294,25 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); From 16eafcb6d124af1a1deed68e84723e8ad1d91261 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Hanotte?= Date: Fri, 13 Apr 2018 15:25:34 +0200 Subject: [PATCH 3/3] PARQUET-968 add proto flag to enable writing using specs-compliant schemas (#2) * PARQUET-968 Add flag to write using specs compliant schemas For users that require backward compatibility with parquet 1.9.0 and older, the flag "parquet.proto.writeSpecsCompliant" is introduced to allow writing collection using the old style (using repeated and not using the LIST and MAP wrappers that are recommended by the parquet specs). * PARQUET-968 Add InputOutputFormat tests to validate read/write --- .../parquet/proto/ProtoMessageConverter.java | 43 +- .../parquet/proto/ProtoSchemaConverter.java | 30 +- .../parquet/proto/ProtoWriteSupport.java | 76 ++- .../proto/ProtoInputOutputFormatTest.java | 120 +++++ .../proto/ProtoSchemaConverterTest.java | 185 ++++++- .../parquet/proto/ProtoWriteSupportTest.java | 501 +++++++++++++++++- .../parquet/proto/utils/WriteUsingMR.java | 10 +- .../src/test/resources/TestProto3.proto | 8 + .../src/test/resources/TestProtobuf.proto | 8 + 9 files changed, 920 insertions(+), 61 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index bb9930b6fe..d5f43e6b35 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -131,12 +131,13 @@ public void add(Object value) { }; } - OriginalType originalType = parquetType.getOriginalType() == null ? OriginalType.UTF8 : parquetType.getOriginalType(); - switch (originalType) { - case LIST: return new ListConverter(parentBuilder, fieldDescriptor, parquetType); - case MAP: return new MapConverter(parentBuilder, fieldDescriptor, parquetType); - default: return newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType); + if (OriginalType.LIST == parquetType.getOriginalType()) { + return new ListConverter(parentBuilder, fieldDescriptor, parquetType); } + if (OriginalType.MAP == parquetType.getOriginalType()) { + return new MapConverter(parentBuilder, fieldDescriptor, parquetType); + } + return newScalarConverter(parent, parentBuilder, fieldDescriptor, parquetType); } private Converter newScalarConverter(ParentValueContainer pvc, Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { @@ -363,38 +364,38 @@ public void addBinary(Binary binary) { *

* A LIST wrapper is created in parquet for the above mentioned protobuf schema: * message SimpleList { - * required group first_array (LIST) = 1 { - * repeated int32 element; + * optional group first_array (LIST) = 1 { + * repeated group list { + * optional int32 element; + * } * } * } *

* The LIST wrappers are used by 3rd party tools, such as Hive, to read parquet arrays. The wrapper contains - * one only one field: either a primitive field (like in the example above, where we have an array of ints) or - * another group (array of messages). + * a repeated group named 'list', itself containing only one field called 'element' of the type of the repeated + * object (can be a primitive as in this example or a group in case of a repeated message in protobuf). */ final class ListConverter extends GroupConverter { private final Converter converter; - private final boolean listOfMessage; public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { OriginalType originalType = parquetType.getOriginalType(); - if (originalType != OriginalType.LIST) { + if (originalType != OriginalType.LIST || parquetType.isPrimitive()) { throw new ParquetDecodingException("Expected LIST wrapper. Found: " + originalType + " instead."); } - listOfMessage = fieldDescriptor.getJavaType() == JavaType.MESSAGE; + GroupType rootWrapperType = parquetType.asGroupType(); + if (!rootWrapperType.containsField("list") || rootWrapperType.getType("list").isPrimitive()) { + throw new ParquetDecodingException("Expected repeated 'list' group inside LIST wrapperr but got: " + rootWrapperType); + } - Type parquetSchema; - if (parquetType.asGroupType().containsField("list")) { - parquetSchema = parquetType.asGroupType().getType("list"); - if (parquetSchema.asGroupType().containsField("element")) { - parquetSchema = parquetSchema.asGroupType().getType("element"); - } - } else { - throw new ParquetDecodingException("Expected list but got: " + parquetType); + GroupType listType = rootWrapperType.getType("list").asGroupType(); + if (!listType.containsField("element")) { + throw new ParquetDecodingException("Expected 'element' inside repeated list group but got: " + listType); } - converter = newMessageConverter(parentBuilder, fieldDescriptor, parquetSchema); + Type elementType = listType.getType("element"); + converter = newMessageConverter(parentBuilder, fieldDescriptor, elementType); } @Override diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index eae54ebefa..9c1bf27c93 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -48,6 +48,22 @@ public class ProtoSchemaConverter { private static final Logger LOG = LoggerFactory.getLogger(ProtoSchemaConverter.class); + private final boolean parquetSpecsCompliant; + + public ProtoSchemaConverter() { + this(false); + } + + /** + * Instanciate a schema converter to get the parquet schema corresponding to protobuf classes. + * @param parquetSpecsCompliant If set to false, the parquet schema generated will be using the old + * schema style (prior to PARQUET-968) to provide backward-compatibility + * but which does not use LIST and MAP wrappers around collections as required + * by the parquet specifications. If set to true, specs compliant schemas are used. + */ + public ProtoSchemaConverter(boolean parquetSpecsCompliant) { + this.parquetSpecsCompliant = parquetSpecsCompliant; + } public MessageType convert(Class protobufClass) { LOG.debug("Converting protocol buffer class \"" + protobufClass + "\" to parquet schema."); @@ -86,7 +102,8 @@ private Builder>, GroupBuilder> addF } ParquetType parquetType = getParquetType(descriptor); - if (descriptor.isRepeated()) { + if (descriptor.isRepeated() && parquetSpecsCompliant) { + // the old schema style did not include the LIST wrapper around repeated fields return addRepeatedPrimitive(descriptor, parquetType.primitiveType, parquetType.originalType, builder); } @@ -98,7 +115,7 @@ private Builder>, GroupBuilder> addR OriginalType originalType, final GroupBuilder builder) { return builder - .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) + .group(Type.Repetition.OPTIONAL).as(OriginalType.LIST) .group(Type.Repetition.REPEATED) .primitive(primitiveType, Type.Repetition.REQUIRED).as(originalType) .named("element") @@ -108,7 +125,7 @@ private Builder>, GroupBuilder> addR private GroupBuilder> addRepeatedMessage(FieldDescriptor descriptor, GroupBuilder builder) { GroupBuilder>>> result = builder - .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) + .group(Type.Repetition.OPTIONAL).as(OriginalType.LIST) .group(Type.Repetition.REPEATED) .group(Type.Repetition.OPTIONAL); @@ -118,9 +135,12 @@ private GroupBuilder> addRepeatedMessage(FieldDescriptor des } private GroupBuilder> addMessageField(FieldDescriptor descriptor, final GroupBuilder builder) { - if (descriptor.isMapField()) { + if (descriptor.isMapField() && parquetSpecsCompliant) { + // the old schema style did not include the MAP wrapper around map groups return addMapField(descriptor, builder); - } else if (descriptor.isRepeated()) { + } + if (descriptor.isRepeated() && parquetSpecsCompliant) { + // the old schema style did not include the LIST wrapper around repeated messages return addRepeatedMessage(descriptor, builder); } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index bb75e71748..4d70632061 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -47,7 +47,13 @@ public class ProtoWriteSupport extends WriteSupport< private static final Logger LOG = LoggerFactory.getLogger(ProtoWriteSupport.class); public static final String PB_CLASS_WRITE = "parquet.proto.writeClass"; + // PARQUET-968 introduces changes to allow writing specs compliant schemas with parquet-protobuf. + // In the past, collection were not written using the LIST and MAP wrappers and thus were not compliant + // with the parquet specs. This flag, is set to true, allows to write using spec compliant schemas + // but is set to false by default to keep backward compatibility + public static final String PB_SPECS_COMPLIANT_WRITE = "parquet.proto.writeSpecsCompliant"; + private boolean writeSpecsCompliant = false; private RecordConsumer recordConsumer; private Class protoMessage; private MessageWriter messageWriter; @@ -68,6 +74,16 @@ public static void setSchema(Configuration configuration, Class extraMetaData = new HashMap(); extraMetaData.put(ProtoReadSupport.PB_CLASS, protoMessage.getName()); extraMetaData.put(ProtoReadSupport.PB_DESCRIPTOR, serializeDescriptor(protoMessage)); + extraMetaData.put(PB_SPECS_COMPLIANT_WRITE, String.valueOf(writeSpecsCompliant)); return new WriteContext(rootSchema, extraMetaData); } @@ -158,8 +176,12 @@ class MessageWriter extends FieldWriter { Type type = schema.getType(name); FieldWriter writer = createWriter(fieldDescriptor, type); - if(fieldDescriptor.isRepeated() && !fieldDescriptor.isMapField()) { - writer = new ArrayWriter(writer); + if(writeSpecsCompliant && fieldDescriptor.isRepeated() && !fieldDescriptor.isMapField()) { + writer = new ArrayWriter(writer); + } + else if (!writeSpecsCompliant && fieldDescriptor.isRepeated()) { + // the old schemas style used to write maps as repeated fields instead of wrapping them in a LIST + writer = new RepeatedWriter(writer); } writer.setFieldName(name); @@ -187,7 +209,7 @@ private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) { } private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type type) { - if (fieldDescriptor.isMapField()) { + if (fieldDescriptor.isMapField() && writeSpecsCompliant) { return createMapWriter(fieldDescriptor, type); } @@ -235,16 +257,16 @@ void writeTopLevelMessage(Object value) { /** Writes message as part of repeated field. It cannot start field*/ @Override final void writeRawValue(Object value) { + recordConsumer.startGroup(); writeAllFields((MessageOrBuilder) value); + recordConsumer.endGroup(); } /** Used for writing nonrepeated (optional, required) fields*/ @Override final void writeField(Object value) { recordConsumer.startField(fieldName, index); - recordConsumer.startGroup(); - writeAllFields((MessageOrBuilder) value); - recordConsumer.endGroup(); + writeRawValue(value); recordConsumer.endField(fieldName, index); } @@ -288,21 +310,11 @@ final void writeField(Object value) { recordConsumer.startField("list", 0); // This is the wrapper group for the array field for (Object listEntry: list) { recordConsumer.startGroup(); - recordConsumer.startField("element", 0); // This is the mandatory inner field - if (!isPrimitive(listEntry)) { - recordConsumer.startGroup(); - } - fieldWriter.writeRawValue(listEntry); - if (!isPrimitive(listEntry)) { - recordConsumer.endGroup(); - } - recordConsumer.endField("element", 0); - recordConsumer.endGroup(); } recordConsumer.endField("list", 0); @@ -312,8 +324,33 @@ final void writeField(Object value) { } } - private boolean isPrimitive(Object listEntry) { - return !(listEntry instanceof Message); + /** + * The RepeatedWriter is used to write collections (lists and maps) using the old style (without LIST and MAP + * wrappers). + */ + class RepeatedWriter extends FieldWriter { + final FieldWriter fieldWriter; + + RepeatedWriter(FieldWriter fieldWriter) { + this.fieldWriter = fieldWriter; + } + + @Override + final void writeRawValue(Object value) { + throw new UnsupportedOperationException("Array has no raw value"); + } + + @Override + final void writeField(Object value) { + recordConsumer.startField(fieldName, index); + List list = (List) value; + + for (Object listEntry: list) { + fieldWriter.writeRawValue(listEntry); + } + + recordConsumer.endField(fieldName, index); + } } /** validates mapping between protobuffer fields and parquet fields.*/ @@ -440,5 +477,4 @@ private String serializeDescriptor(Class protoClass) { DescriptorProtos.DescriptorProto asProto = descriptor.toProto(); return TextFormat.printToString(asProto); } - } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java index 6c01d7b8db..5544dc6887 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java @@ -19,6 +19,7 @@ package org.apache.parquet.proto; import com.google.protobuf.Message; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.proto.test.TestProto3; import org.apache.parquet.proto.test.TestProtobuf; @@ -193,6 +194,125 @@ public void testProto3CustomProtoClass() throws Exception { assertEquals("writtenString", stringValue); } + @Test + public void testRepeatedIntMessageClass() throws Exception { + TestProtobuf.RepeatedIntMessage msgEmpty = TestProtobuf.RepeatedIntMessage.newBuilder().build(); + TestProtobuf.RepeatedIntMessage msgNonEmpty = TestProtobuf.RepeatedIntMessage.newBuilder() + .addRepeatedInt(1).addRepeatedInt(2) + .build(); + + Path outputPath = new WriteUsingMR().write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.RepeatedIntMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testRepeatedIntMessageClassSchemaCompliant() throws Exception { + TestProtobuf.RepeatedIntMessage msgEmpty = TestProtobuf.RepeatedIntMessage.newBuilder().build(); + TestProtobuf.RepeatedIntMessage msgNonEmpty = TestProtobuf.RepeatedIntMessage.newBuilder() + .addRepeatedInt(1).addRepeatedInt(2) + .build(); + + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + + Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.RepeatedIntMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testMapIntMessageClass() throws Exception { + TestProtobuf.MapIntMessage msgEmpty = TestProtobuf.MapIntMessage.newBuilder().build(); + TestProtobuf.MapIntMessage msgNonEmpty = TestProtobuf.MapIntMessage.newBuilder() + .putMapInt(1, 123).putMapInt(2, 234) + .build(); + + Path outputPath = new WriteUsingMR().write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.MapIntMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testMapIntMessageClassSchemaCompliant() throws Exception { + TestProtobuf.MapIntMessage msgEmpty = TestProtobuf.MapIntMessage.newBuilder().build(); + TestProtobuf.MapIntMessage msgNonEmpty = TestProtobuf.MapIntMessage.newBuilder() + .putMapInt(1, 123).putMapInt(2, 234) + .build(); + + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + + Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.MapIntMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testRepeatedInnerMessageClass() throws Exception { + TestProtobuf.RepeatedInnerMessage msgEmpty = TestProtobuf.RepeatedInnerMessage.newBuilder().build(); + TestProtobuf.RepeatedInnerMessage msgNonEmpty = TestProtobuf.RepeatedInnerMessage.newBuilder() + .addRepeatedInnerMessage(TestProtobuf.InnerMessage.newBuilder().setOne("one").build()) + .addRepeatedInnerMessage(TestProtobuf.InnerMessage.newBuilder().setTwo("two").build()) + .build(); + + Path outputPath = new WriteUsingMR().write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.RepeatedInnerMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + + @Test + public void testRepeatedInnerMessageClassSchemaCompliant() throws Exception { + TestProtobuf.RepeatedInnerMessage msgEmpty = TestProtobuf.RepeatedInnerMessage.newBuilder().build(); + TestProtobuf.RepeatedInnerMessage msgNonEmpty = TestProtobuf.RepeatedInnerMessage.newBuilder() + .addRepeatedInnerMessage(TestProtobuf.InnerMessage.newBuilder().setOne("one").build()) + .addRepeatedInnerMessage(TestProtobuf.InnerMessage.newBuilder().setTwo("two").build()) + .build(); + + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + + Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); + ReadUsingMR readUsingMR = new ReadUsingMR(); + String customClass = TestProtobuf.RepeatedInnerMessage.class.getName(); + ProtoReadSupport.setProtobufClass(readUsingMR.getConfiguration(), customClass); + List result = readUsingMR.read(outputPath); + + assertEquals(2, result.size()); + assertEquals(msgEmpty, result.get(0)); + assertEquals(msgNonEmpty, result.get(1)); + } + /** * Runs job that writes input to file and then job reading data back. */ diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index d7ec169ce7..4ca82ac740 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -32,14 +32,17 @@ public class ProtoSchemaConverterTest { /** * Converts given pbClass to parquet schema and compares it with expected parquet schema. */ - private void testConversion(Class pbClass, String parquetSchemaString) throws + private void testConversion(Class pbClass, String parquetSchemaString, boolean parquetSpecsCompliant) throws Exception { - ProtoSchemaConverter protoSchemaConverter = new ProtoSchemaConverter(); + ProtoSchemaConverter protoSchemaConverter = new ProtoSchemaConverter(parquetSpecsCompliant); MessageType schema = protoSchemaConverter.convert(pbClass); MessageType expectedMT = MessageTypeParser.parseMessageType(parquetSchemaString); assertEquals(expectedMT.toString(), schema.toString()); } + private void testConversion(Class pbClass, String parquetSchemaString) throws Exception { + testConversion(pbClass, parquetSchemaString, true); + } /** * Tests that all protocol buffer datatypes are converted to correct parquet datatypes. @@ -122,7 +125,7 @@ public void testConvertRepetition() throws Exception { "message TestProtobuf.SchemaConverterRepetition {\n" + " optional int32 optionalPrimitive = 1;\n" + " required int32 requiredPrimitive = 2;\n" + - " required group repeatedPrimitive (LIST) = 3 {\n" + + " optional group repeatedPrimitive (LIST) = 3 {\n" + " repeated group list {\n" + " required int32 element;\n" + " }\n" + @@ -133,7 +136,7 @@ public void testConvertRepetition() throws Exception { " required group requiredMessage = 8 {\n" + " optional int32 someId= 3;\n" + " }\n" + - " required group repeatedMessage (LIST) = 9 {\n" + + " optional group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + " optional group element {\n" + " optional int32 someId = 3;\n" + @@ -150,7 +153,7 @@ public void testProto3ConvertRepetition() throws Exception { String expectedSchema = "message TestProto3.SchemaConverterRepetition {\n" + " optional int32 optionalPrimitive = 1;\n" + - " required group repeatedPrimitive (LIST) = 3 {\n" + + " optional group repeatedPrimitive (LIST) = 3 {\n" + " repeated group list {\n" + " required int32 element;\n" + " }\n" + @@ -158,7 +161,7 @@ public void testProto3ConvertRepetition() throws Exception { " optional group optionalMessage = 7 {\n" + " optional int32 someId = 3;\n" + " }\n" + - " required group repeatedMessage (LIST) = 9 {\n" + + " optional group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + " optional group element {\n" + " optional int32 someId = 3;\n" + @@ -169,4 +172,174 @@ public void testProto3ConvertRepetition() throws Exception { testConversion(TestProto3.SchemaConverterRepetition.class, expectedSchema); } + + @Test + public void testConvertRepeatedIntMessage() throws Exception { + String expectedSchema = + "message TestProtobuf.RepeatedIntMessage {\n" + + " optional group repeatedInt (LIST) = 1 {\n" + + " repeated group list {\n" + + " required int32 element;\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema); + } + + @Test + public void testConvertRepeatedIntMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProtobuf.RepeatedIntMessage {\n" + + " repeated int32 repeatedInt = 1;\n" + + "}"; + + testConversion(TestProtobuf.RepeatedIntMessage.class, expectedSchema, false); + } + + @Test + public void testProto3ConvertRepeatedIntMessage() throws Exception { + String expectedSchema = + "message TestProto3.RepeatedIntMessage {\n" + + " optional group repeatedInt (LIST) = 1 {\n" + + " repeated group list {\n" + + " required int32 element;\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema); + } + + @Test + public void testProto3ConvertRepeatedIntMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProto3.RepeatedIntMessage {\n" + + " repeated int32 repeatedInt = 1;\n" + + "}"; + + testConversion(TestProto3.RepeatedIntMessage.class, expectedSchema, false); + } + + @Test + public void testConvertRepeatedInnerMessage() throws Exception { + String expectedSchema = + "message TestProtobuf.RepeatedInnerMessage {\n" + + " optional group repeatedInnerMessage (LIST) = 1 {\n" + + " repeated group list {\n" + + " optional group element {\n" + + " optional binary one (UTF8) = 1;\n" + + " optional binary two (UTF8) = 2;\n" + + " optional binary three (UTF8) = 3;\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema); + } + + @Test + public void testConvertRepeatedInnerMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProtobuf.RepeatedInnerMessage {\n" + + " repeated group repeatedInnerMessage = 1 {\n" + + " optional binary one (UTF8) = 1;\n" + + " optional binary two (UTF8) = 2;\n" + + " optional binary three (UTF8) = 3;\n" + + " }\n" + + "}"; + + testConversion(TestProtobuf.RepeatedInnerMessage.class, expectedSchema, false); + } + + @Test + public void testProto3ConvertRepeatedInnerMessage() throws Exception { + String expectedSchema = + "message TestProto3.RepeatedInnerMessage {\n" + + " optional group repeatedInnerMessage (LIST) = 1 {\n" + + " repeated group list {\n" + + " optional group element {\n" + + " optional binary one (UTF8) = 1;\n" + + " optional binary two (UTF8) = 2;\n" + + " optional binary three (UTF8) = 3;\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema); + } + + @Test + public void testProto3ConvertRepeatedInnerMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProto3.RepeatedInnerMessage {\n" + + " repeated group repeatedInnerMessage = 1 {\n" + + " optional binary one (UTF8) = 1;\n" + + " optional binary two (UTF8) = 2;\n" + + " optional binary three (UTF8) = 3;\n" + + " }\n" + + "}"; + + testConversion(TestProto3.RepeatedInnerMessage.class, expectedSchema, false); + } + + @Test + public void testConvertMapIntMessage() throws Exception { + String expectedSchema = + "message TestProtobuf.MapIntMessage {\n" + + " optional group mapInt (MAP) = 1 {\n" + + " repeated group key_value {\n" + + " required int32 key;\n" + + " optional int32 value;\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProtobuf.MapIntMessage.class, expectedSchema); + } + + @Test + public void testConvertMapIntMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProtobuf.MapIntMessage {\n" + + " repeated group mapInt = 1 {\n" + + " optional int32 key = 1;\n" + + " optional int32 value = 2;\n" + + " }\n" + + "}"; + + testConversion(TestProtobuf.MapIntMessage.class, expectedSchema, false); + } + + @Test + public void testProto3ConvertMapIntMessage() throws Exception { + String expectedSchema = + "message TestProto3.MapIntMessage {\n" + + " optional group mapInt (MAP) = 1 {\n" + + " repeated group key_value {\n" + + " required int32 key;\n" + + " optional int32 value;\n" + + " }\n" + + " }\n" + + "}"; + + testConversion(TestProto3.MapIntMessage.class, expectedSchema); + } + + @Test + public void testProto3ConvertMapIntMessageNonSpecsCompliant() throws Exception { + String expectedSchema = + "message TestProto3.MapIntMessage {\n" + + " repeated group mapInt = 1 {\n" + + " optional int32 key = 1;\n" + + " optional int32 value = 2;\n" + + " }\n" + + "}"; + + testConversion(TestProto3.MapIntMessage.class, expectedSchema, false); + } } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index de27ebf3f4..f71229c222 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -31,8 +31,12 @@ public class ProtoWriteSupportTest { private ProtoWriteSupport createReadConsumerInstance(Class cls, RecordConsumer readConsumerMock) { + return createReadConsumerInstance(cls, readConsumerMock, new Configuration()); + } + + private ProtoWriteSupport createReadConsumerInstance(Class cls, RecordConsumer readConsumerMock, Configuration conf) { ProtoWriteSupport support = new ProtoWriteSupport(cls); - support.init(new Configuration()); + support.init(conf); support.prepareForWrite(readConsumerMock); return support; } @@ -80,9 +84,11 @@ public void testProto3SimplestMessage() throws Exception { } @Test - public void testRepeatedIntMessage() throws Exception { + public void testRepeatedIntMessageSpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock, conf); TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); msg.addRepeatedInt(1323); @@ -117,9 +123,67 @@ public void testRepeatedIntMessage() throws Exception { } @Test - public void testProto3RepeatedIntMessage() throws Exception { + public void testRepeatedIntMessage() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.RepeatedIntMessage.class, readConsumerMock); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock); + + TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); + msg.addRepeatedInt(1323); + msg.addRepeatedInt(54469); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("repeatedInt", 0); + inOrder.verify(readConsumerMock).addInteger(1323); + inOrder.verify(readConsumerMock).addInteger(54469); + inOrder.verify(readConsumerMock).endField("repeatedInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testRepeatedIntMessageEmptySpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock, conf); + + TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testRepeatedIntMessageEmpty() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock); + + TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3RepeatedIntMessageSpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.RepeatedIntMessage.class, readConsumerMock, conf); TestProto3.RepeatedIntMessage.Builder msg = TestProto3.RepeatedIntMessage.newBuilder(); msg.addRepeatedInt(1323); @@ -153,6 +217,290 @@ public void testProto3RepeatedIntMessage() throws Exception { Mockito.verifyNoMoreInteractions(readConsumerMock); } + @Test + public void testProto3RepeatedIntMessage() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.RepeatedIntMessage.class, readConsumerMock); + + TestProto3.RepeatedIntMessage.Builder msg = TestProto3.RepeatedIntMessage.newBuilder(); + msg.addRepeatedInt(1323); + msg.addRepeatedInt(54469); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("repeatedInt", 0); + inOrder.verify(readConsumerMock).addInteger(1323); + inOrder.verify(readConsumerMock).addInteger(54469); + inOrder.verify(readConsumerMock).endField("repeatedInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3RepeatedIntMessageEmptySpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock, conf); + + TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3RepeatedIntMessageEmpty() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock); + + TestProtobuf.RepeatedIntMessage.Builder msg = TestProtobuf.RepeatedIntMessage.newBuilder(); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testMapIntMessageSpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock, conf); + + TestProtobuf.MapIntMessage.Builder msg = TestProtobuf.MapIntMessage.newBuilder(); + msg.putMapInt(123, 1); + msg.putMapInt(234, 2); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("mapInt", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key_value", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(123); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(1); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(234); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(2); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("key_value", 0); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("mapInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testMapIntMessage() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock); + + TestProtobuf.MapIntMessage.Builder msg = TestProtobuf.MapIntMessage.newBuilder(); + msg.putMapInt(123, 1); + msg.putMapInt(234, 2); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("mapInt", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(123); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(1); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(234); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(2); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("mapInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testMapIntMessageEmptySpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock, conf); + + TestProtobuf.MapIntMessage.Builder msg = TestProtobuf.MapIntMessage.newBuilder(); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testMapIntMessageEmpty() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock); + + TestProtobuf.MapIntMessage.Builder msg = TestProtobuf.MapIntMessage.newBuilder(); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3MapIntMessageSpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock, conf); + + TestProto3.MapIntMessage.Builder msg = TestProto3.MapIntMessage.newBuilder(); + msg.putMapInt(123, 1); + msg.putMapInt(234, 2); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("mapInt", 0); + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key_value", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(123); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(1); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(234); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(2); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("key_value", 0); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("mapInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3MapIntMessage() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock); + + TestProto3.MapIntMessage.Builder msg = TestProto3.MapIntMessage.newBuilder(); + msg.putMapInt(123, 1); + msg.putMapInt(234, 2); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("mapInt", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(123); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(1); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("key", 0); + inOrder.verify(readConsumerMock).addInteger(234); + inOrder.verify(readConsumerMock).endField("key", 0); + inOrder.verify(readConsumerMock).startField("value", 1); + inOrder.verify(readConsumerMock).addInteger(2); + inOrder.verify(readConsumerMock).endField("value", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("mapInt", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3MapIntMessageEmptySpecsCompliant() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock, conf); + + TestProto3.MapIntMessage.Builder msg = TestProto3.MapIntMessage.newBuilder(); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3MapIntMessageEmpty() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock); + + TestProto3.MapIntMessage.Builder msg = TestProto3.MapIntMessage.newBuilder(); + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + @Test public void testRepeatedInnerMessageMessage_message() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); @@ -165,6 +513,37 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("inner", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("one", 0); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); + inOrder.verify(readConsumerMock).endField("one", 0); + inOrder.verify(readConsumerMock).startField("two", 1); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); + inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("inner", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testRepeatedInnerMessageSpecsCompliantMessage_message() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock, conf); + + TestProtobuf.TopMessage.Builder msg = TestProtobuf.TopMessage.newBuilder(); + msg.addInnerBuilder().setOne("one").setTwo("two"); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); @@ -192,7 +571,7 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { @Test public void testProto3RepeatedInnerMessageMessage_message() throws Exception { - RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class);; ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.TopMessage.class, readConsumerMock); TestProto3.TopMessage.Builder msg = TestProto3.TopMessage.newBuilder(); @@ -202,6 +581,37 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("inner", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("one", 0); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); + inOrder.verify(readConsumerMock).endField("one", 0); + inOrder.verify(readConsumerMock).startField("two", 1); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); + inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("inner", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3RepeatedInnerMessageSpecsCompliantMessage_message() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.TopMessage.class, readConsumerMock, conf); + + TestProto3.TopMessage.Builder msg = TestProto3.TopMessage.newBuilder(); + msg.addInnerBuilder().setOne("one").setTwo("two"); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); @@ -227,10 +637,13 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { Mockito.verifyNoMoreInteractions(readConsumerMock); } + @Test - public void testRepeatedInnerMessageMessage_scalar() throws Exception { + public void testRepeatedInnerMessageSpecsCompliantMessage_scalar() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock, conf); TestProtobuf.TopMessage.Builder msg = TestProtobuf.TopMessage.newBuilder(); msg.addInnerBuilder().setOne("one"); @@ -274,6 +687,41 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { Mockito.verifyNoMoreInteractions(readConsumerMock); } + @Test + public void testRepeatedInnerMessageMessage_scalar() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock); + + TestProtobuf.TopMessage.Builder msg = TestProtobuf.TopMessage.newBuilder(); + msg.addInnerBuilder().setOne("one"); + msg.addInnerBuilder().setTwo("two"); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("inner", 0); + + //first inner message + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("one", 0); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); + inOrder.verify(readConsumerMock).endField("one", 0); + inOrder.verify(readConsumerMock).endGroup(); + + //second inner message + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("two", 1); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); + inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("inner", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + @Test public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); @@ -287,6 +735,43 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); + inOrder.verify(readConsumerMock).startField("inner", 0); + + //first inner message + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("one", 0); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); + inOrder.verify(readConsumerMock).endField("one", 0); + inOrder.verify(readConsumerMock).endGroup(); + + //second inner message + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("two", 1); + inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); + inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + + inOrder.verify(readConsumerMock).endField("inner", 0); + inOrder.verify(readConsumerMock).endMessage(); + Mockito.verifyNoMoreInteractions(readConsumerMock); + } + + @Test + public void testProto3RepeatedInnerMessageSpecsCompliantMessage_scalar() throws Exception { + RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); + Configuration conf = new Configuration(); + ProtoWriteSupport.setWriteSpecsCompliant(conf, true); + ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.TopMessage.class, readConsumerMock, conf); + + TestProto3.TopMessage.Builder msg = TestProto3.TopMessage.newBuilder(); + msg.addInnerBuilder().setOne("one"); + msg.addInnerBuilder().setTwo("two"); + + instance.write(msg.build()); + + InOrder inOrder = Mockito.inOrder(readConsumerMock); + inOrder.verify(readConsumerMock).startMessage(); inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/utils/WriteUsingMR.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/utils/WriteUsingMR.java index d18076a642..55f9237ec5 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/utils/WriteUsingMR.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/utils/WriteUsingMR.java @@ -46,10 +46,18 @@ public class WriteUsingMR { private static final Logger LOG = LoggerFactory.getLogger(WriteUsingMR.class); - Configuration conf = new Configuration(); + private final Configuration conf; private static List inputMessages; Path outputPath; + public WriteUsingMR() { + this(new Configuration()); + } + + public WriteUsingMR(Configuration conf) { + this.conf = new Configuration(); + } + public Configuration getConfiguration() { return conf; } diff --git a/parquet-protobuf/src/test/resources/TestProto3.proto b/parquet-protobuf/src/test/resources/TestProto3.proto index 1896445306..e49eef5727 100644 --- a/parquet-protobuf/src/test/resources/TestProto3.proto +++ b/parquet-protobuf/src/test/resources/TestProto3.proto @@ -124,6 +124,14 @@ message RepeatedIntMessage { repeated int32 repeatedInt = 1; } +message RepeatedInnerMessage { + repeated InnerMessage repeatedInnerMessage = 1; +} + +message MapIntMessage { + map mapInt = 1; +} + message HighIndexMessage { repeated int32 repeatedInt = 50000; } diff --git a/parquet-protobuf/src/test/resources/TestProtobuf.proto b/parquet-protobuf/src/test/resources/TestProtobuf.proto index d7cdf03a91..d4ab4c7dcd 100644 --- a/parquet-protobuf/src/test/resources/TestProtobuf.proto +++ b/parquet-protobuf/src/test/resources/TestProtobuf.proto @@ -122,6 +122,14 @@ message RepeatedIntMessage { repeated int32 repeatedInt = 1; } +message RepeatedInnerMessage { + repeated InnerMessage repeatedInnerMessage = 1; +} + +message MapIntMessage { + map mapInt = 1; +} + message HighIndexMessage { repeated int32 repeatedInt = 50000; }