From a8bd704138e9a3d99d96752766b80b5fb576802e Mon Sep 17 00:00:00 2001 From: Constantin Muraru Date: Wed, 6 Sep 2017 00:06:11 +0300 Subject: [PATCH] Pick up commit from @andredasilvapinto https://github.com/andredasilvapinto/parquet-mr/commit/dfa9701a4d843bb7cd1d429d86d17811b735f33c --- .../parquet/proto/ProtoMessageConverter.java | 6 +- .../parquet/proto/ProtoSchemaConverter.java | 32 ++++---- .../parquet/proto/ProtoWriteSupport.java | 81 ++++++++++--------- .../proto/ProtoSchemaConverterTest.java | 10 ++- .../parquet/proto/ProtoWriteSupportTest.java | 28 +++++++ 5 files changed, 97 insertions(+), 60 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 953994f1c1..bb9930b6fe 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -388,7 +388,7 @@ public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor if (parquetType.asGroupType().containsField("list")) { parquetSchema = parquetType.asGroupType().getType("list"); if (parquetSchema.asGroupType().containsField("element")) { - parquetSchema.asGroupType().getType("element"); + parquetSchema = parquetSchema.asGroupType().getType("element"); } } else { throw new ParquetDecodingException("Expected list but got: " + parquetType); @@ -403,10 +403,6 @@ public Converter getConverter(int fieldIndex) { throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); } - if (listOfMessage) { - return converter; - } - return new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index f3dd11db38..eae54ebefa 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -19,6 +19,7 @@ package org.apache.parquet.proto; import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; import com.google.protobuf.Message; import com.twitter.elephantbird.util.Protobufs; @@ -59,8 +60,8 @@ public MessageType convert(Class protobufClass) { } /* Iterates over list of fields. **/ - private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { - for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) { + private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { + for (FieldDescriptor fieldDescriptor : fieldDescriptors) { groupBuilder = addField(fieldDescriptor, groupBuilder) .id(fieldDescriptor.getNumber()) @@ -69,7 +70,7 @@ private GroupBuilder convertFields(GroupBuilder groupBuilder, List Builder>, GroupBuilder> addField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + private Builder>, GroupBuilder> addField(FieldDescriptor descriptor, final GroupBuilder builder) { if (descriptor.getJavaType() == JavaType.MESSAGE) { return addMessageField(descriptor, builder); } @@ -92,7 +93,7 @@ private Builder>, GroupBuilder> addF return builder.primitive(parquetType.primitiveType, getRepetition(descriptor)).as(parquetType.originalType); } - private Builder>, GroupBuilder> addRepeatedPrimitive(Descriptors.FieldDescriptor descriptor, + private Builder>, GroupBuilder> addRepeatedPrimitive(FieldDescriptor descriptor, PrimitiveTypeName primitiveType, OriginalType originalType, final GroupBuilder builder) { @@ -104,18 +105,19 @@ private Builder>, GroupBuilder> addR .named("list"); } - private GroupBuilder> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder builder) { - GroupBuilder>> result = + private GroupBuilder> addRepeatedMessage(FieldDescriptor descriptor, GroupBuilder builder) { + GroupBuilder>>> result = builder .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) - .group(Type.Repetition.REPEATED); + .group(Type.Repetition.REPEATED) + .group(Type.Repetition.OPTIONAL); convertFields(result, descriptor.getMessageType().getFields()); - return result.named("list"); + return result.named("element").named("list"); } - private GroupBuilder> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { + private GroupBuilder> addMessageField(FieldDescriptor descriptor, final GroupBuilder builder) { if (descriptor.isMapField()) { return addMapField(descriptor, builder); } else if (descriptor.isRepeated()) { @@ -128,8 +130,8 @@ private GroupBuilder> addMessageField(Descriptors.FieldDescr return group; } - private GroupBuilder> addMapField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { - List fields = descriptor.getMessageType().getFields(); + private GroupBuilder> addMapField(FieldDescriptor descriptor, final GroupBuilder builder) { + List fields = descriptor.getMessageType().getFields(); if (fields.size() != 2) { throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); } @@ -137,7 +139,7 @@ private GroupBuilder> addMapField(Descriptors.FieldDescripto ParquetType mapKeyParquetType = getParquetType(fields.get(0)); GroupBuilder>> group = builder - .group(Type.Repetition.REQUIRED).as(OriginalType.MAP) + .group(Type.Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3 .group(Type.Repetition.REPEATED) // key_value wrapper .primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key"); @@ -145,7 +147,7 @@ private GroupBuilder> addMapField(Descriptors.FieldDescripto .named("key_value"); } - private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) { + private ParquetType getParquetType(FieldDescriptor fieldDescriptor) { JavaType javaType = fieldDescriptor.getJavaType(); switch (javaType) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 8e2b4aeb44..bb75e71748 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -18,13 +18,9 @@ */ package org.apache.parquet.proto; -import com.google.protobuf.ByteString; -import com.google.protobuf.DescriptorProtos; -import com.google.protobuf.Descriptors; -import com.google.protobuf.MapEntry; -import com.google.protobuf.Message; -import com.google.protobuf.MessageOrBuilder; -import com.google.protobuf.TextFormat; +import com.google.protobuf.*; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; import com.twitter.elephantbird.util.Protobufs; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.hadoop.BadConfigurationException; @@ -32,10 +28,7 @@ import org.apache.parquet.io.InvalidRecordException; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.RecordConsumer; -import org.apache.parquet.schema.GroupType; -import org.apache.parquet.schema.IncompatibleSchemaModificationException; -import org.apache.parquet.schema.MessageType; -import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.*; import org.apache.parquet.schema.Type; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,7 +106,7 @@ public WriteContext init(Configuration configuration) { } MessageType rootSchema = new ProtoSchemaConverter().convert(protoMessage); - Descriptors.Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage); + Descriptor messageDescriptor = Protobufs.getMessageDescriptor(protoMessage); validatedMapping(messageDescriptor, rootSchema); this.messageWriter = new MessageWriter(messageDescriptor, rootSchema); @@ -156,11 +149,11 @@ class MessageWriter extends FieldWriter { final FieldWriter[] fieldWriters; @SuppressWarnings("unchecked") - MessageWriter(Descriptors.Descriptor descriptor, GroupType schema) { - List fields = descriptor.getFields(); + MessageWriter(Descriptor descriptor, GroupType schema) { + List fields = descriptor.getFields(); fieldWriters = (FieldWriter[]) Array.newInstance(FieldWriter.class, fields.size()); - for (Descriptors.FieldDescriptor fieldDescriptor: fields) { + for (FieldDescriptor fieldDescriptor: fields) { String name = fieldDescriptor.getName(); Type type = schema.getType(name); FieldWriter writer = createWriter(fieldDescriptor, type); @@ -176,7 +169,7 @@ class MessageWriter extends FieldWriter { } } - private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) { switch (fieldDescriptor.getJavaType()) { case STRING: return new StringWriter() ; @@ -193,7 +186,7 @@ private FieldWriter createWriter(Descriptors.FieldDescriptor fieldDescriptor, Ty return unknownType(fieldDescriptor);//should not be executed, always throws exception. } - private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { + private FieldWriter createMessageWriter(FieldDescriptor fieldDescriptor, Type type) { if (fieldDescriptor.isMapField()) { return createMapWriter(fieldDescriptor, type); } @@ -203,7 +196,7 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip private GroupType getGroupType(Type type) { if (type.getOriginalType() == OriginalType.LIST) { - return type.asGroupType().getType("list").asGroupType(); + return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType(); } if (type.getOriginalType() == OriginalType.MAP) { @@ -213,20 +206,20 @@ private GroupType getGroupType(Type type) { return type.asGroupType(); } - private MapWriter createMapWriter(Descriptors.FieldDescriptor fieldDescriptor, Type type) { - List fields = fieldDescriptor.getMessageType().getFields(); + private MapWriter createMapWriter(FieldDescriptor fieldDescriptor, Type type) { + List fields = fieldDescriptor.getMessageType().getFields(); if (fields.size() != 2) { throw new UnsupportedOperationException("Expected two fields for the map (key/value), but got: " + fields); } // KeyFieldWriter - Descriptors.FieldDescriptor keyProtoField = fields.get(0); + FieldDescriptor keyProtoField = fields.get(0); FieldWriter keyWriter = createWriter(keyProtoField, type); keyWriter.setFieldName(keyProtoField.getName()); keyWriter.setIndex(0); // ValueFieldWriter - Descriptors.FieldDescriptor valueProtoField = fields.get(1); + FieldDescriptor valueProtoField = fields.get(1); FieldWriter valueWriter = createWriter(valueProtoField, type); valueWriter.setFieldName(valueProtoField.getName()); valueWriter.setIndex(1); @@ -257,10 +250,10 @@ final void writeField(Object value) { private void writeAllFields(MessageOrBuilder pb) { //returns changed fields with values. Map is ordered by id. - Map changedPbFields = pb.getAllFields(); + Map changedPbFields = pb.getAllFields(); - for (Map.Entry entry : changedPbFields.entrySet()) { - Descriptors.FieldDescriptor fieldDescriptor = entry.getKey(); + for (Map.Entry entry : changedPbFields.entrySet()) { + FieldDescriptor fieldDescriptor = entry.getKey(); if(fieldDescriptor.isExtension()) { // Field index of an extension field might overlap with a base field. @@ -295,13 +288,21 @@ final void writeField(Object value) { recordConsumer.startField("list", 0); // This is the wrapper group for the array field for (Object listEntry: list) { recordConsumer.startGroup(); - if (isPrimitive(listEntry)) { - recordConsumer.startField("element", 0); + + recordConsumer.startField("element", 0); // This is the mandatory inner field + + if (!isPrimitive(listEntry)) { + recordConsumer.startGroup(); } + fieldWriter.writeRawValue(listEntry); - if (isPrimitive(listEntry)) { - recordConsumer.endField("element", 0); + + if (!isPrimitive(listEntry)) { + recordConsumer.endGroup(); } + + recordConsumer.endField("element", 0); + recordConsumer.endGroup(); } recordConsumer.endField("list", 0); @@ -316,10 +317,10 @@ private boolean isPrimitive(Object listEntry) { } /** validates mapping between protobuffer fields and parquet fields.*/ - private void validatedMapping(Descriptors.Descriptor descriptor, GroupType parquetSchema) { - List allFields = descriptor.getFields(); + private void validatedMapping(Descriptor descriptor, GroupType parquetSchema) { + List allFields = descriptor.getFields(); - for (Descriptors.FieldDescriptor fieldDescriptor: allFields) { + for (FieldDescriptor fieldDescriptor: allFields) { String fieldName = fieldDescriptor.getName(); int fieldIndex = fieldDescriptor.getIndex(); int parquetIndex = parquetSchema.getFieldIndex(fieldName); @@ -370,10 +371,16 @@ final void writeRawValue(Object value) { recordConsumer.startGroup(); recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field - for(MapEntry entry : (Collection>) value) { + for (Message msg : (Collection) value) { recordConsumer.startGroup(); - keyWriter.writeField(entry.getKey()); - valueWriter.writeField(entry.getValue()); + + final Descriptor descriptorForType = msg.getDescriptorForType(); + final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key"); + final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value"); + + keyWriter.writeField(msg.getField(keyDesc)); + valueWriter.writeField(msg.getField(valueDesc)); + recordConsumer.endGroup(); } @@ -421,7 +428,7 @@ final void writeRawValue(Object value) { } } - private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) { + private FieldWriter unknownType(FieldDescriptor fieldDescriptor) { String exceptionMsg = "Unknown type with descriptor \"" + fieldDescriptor + "\" and type \"" + fieldDescriptor.getJavaType() + "\"."; throw new InvalidRecordException(exceptionMsg); @@ -429,7 +436,7 @@ private FieldWriter unknownType(Descriptors.FieldDescriptor fieldDescriptor) { /** Returns message descriptor as JSON String*/ private String serializeDescriptor(Class protoClass) { - Descriptors.Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass); + Descriptor descriptor = Protobufs.getMessageDescriptor(protoClass); DescriptorProtos.DescriptorProto asProto = descriptor.toProto(); return TextFormat.printToString(asProto); } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 70bc1f79d9..d7ec169ce7 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -103,7 +103,7 @@ public void testProto3ConvertAllDatatypes() throws Exception { " optional binary optionalEnum (ENUM) = 18;" + " optional int32 someInt32 = 19;" + " optional binary someString (UTF8) = 20;" + - " required group optionalMap (MAP) = 21 {\n" + + " optional group optionalMap (MAP) = 21 {\n" + " repeated group key_value {\n" + " required int64 key;\n" + " optional group value {\n" + @@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }" + "}"; @@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }\n" + "}"; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index e00facfc06..de27ebf3f4 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -169,6 +169,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); @@ -177,6 +180,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); @@ -201,12 +207,18 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); inOrder.verify(readConsumerMock).startGroup(); + + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -235,17 +247,25 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -274,17 +294,25 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup();