diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index e478652207..e8ace746a4 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -374,7 +374,6 @@ public void addBinary(Binary binary) { */ final class ListConverter extends GroupConverter { private final Converter converter; - private final boolean listOfMessage; public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { OriginalType originalType = parquetType.getOriginalType(); @@ -382,13 +381,11 @@ public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor throw new ParquetDecodingException("Expected LIST wrapper. Found: " + originalType + " instead."); } - listOfMessage = fieldDescriptor.getJavaType() == JavaType.MESSAGE; - Type parquetSchema; if (parquetType.asGroupType().containsField("list")) { parquetSchema = parquetType.asGroupType().getType("list"); if (parquetSchema.asGroupType().containsField("element")) { - parquetSchema.asGroupType().getType("element"); + parquetSchema = parquetSchema.asGroupType().getType("element"); } } else { throw new ParquetDecodingException("Expected list but got: " + parquetType); @@ -403,10 +400,6 @@ public Converter getConverter(int fieldIndex) { throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); } - if (listOfMessage) { - return converter; - } - return new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { @@ -447,10 +440,10 @@ public MapConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor f } Type parquetSchema; - if (parquetType.asGroupType().containsField("map")){ - parquetSchema = parquetType.asGroupType().getType("map"); + if (parquetType.asGroupType().containsField("key_value")){ + parquetSchema = parquetType.asGroupType().getType("key_value"); } else { - throw new ParquetDecodingException("Expected map but got: " + parquetType); + throw new ParquetDecodingException("Expected key_value but got: " + parquetType); } converter = newMessageConverter(parentBuilder, fieldDescriptor, parquetSchema); diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java index a5b4edebe5..d412362ba8 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoSchemaConverter.java @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,6 +19,7 @@ package org.apache.parquet.proto; import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; import com.google.protobuf.Message; import com.twitter.elephantbird.util.Protobufs; @@ -26,6 +27,7 @@ import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; import org.apache.parquet.schema.Types.Builder; import org.apache.parquet.schema.Types.GroupBuilder; @@ -59,7 +61,7 @@ public MessageType convert(Class protobufClass) { } /* Iterates over list of fields. **/ - private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { + private GroupBuilder convertFields(GroupBuilder groupBuilder, List fieldDescriptors) { for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) { groupBuilder = addField(fieldDescriptor, groupBuilder) @@ -105,14 +107,15 @@ private Builder>, GroupBuilder> addR } private GroupBuilder> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder builder) { - GroupBuilder>> result = + GroupBuilder>>> result = builder .group(Type.Repetition.REQUIRED).as(OriginalType.LIST) - .group(Type.Repetition.REPEATED); + .group(Type.Repetition.REPEATED) + .group(Type.Repetition.OPTIONAL); convertFields(result, descriptor.getMessageType().getFields()); - return result.named("list"); + return result.named("element").named("list"); } private GroupBuilder> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder builder) { @@ -137,12 +140,12 @@ private GroupBuilder> addMapField(Descriptors.FieldDescripto ParquetType mapKeyParquetType = getParquetType(fields.get(0)); GroupBuilder>> group = builder - .group(getRepetition(descriptor)).as(OriginalType.MAP) - .group(Type.Repetition.REPEATED).as(OriginalType.MAP_KEY_VALUE) + .group(Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3 + .group(Type.Repetition.REPEATED) .primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key"); return addField(fields.get(1), group).named("value") - .named("map"); + .named("key_value"); } private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 31e386daee..75809d4b53 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,11 @@ import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; -import com.google.protobuf.MapEntry; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.TextFormat; @@ -203,11 +207,11 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip private GroupType getGroupType(Type type) { if (type.getOriginalType() == OriginalType.LIST) { - return type.asGroupType().getType("list").asGroupType(); + return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType(); } if (type.getOriginalType() == OriginalType.MAP) { - return type.asGroupType().getType("map").asGroupType().getType("value").asGroupType(); + return type.asGroupType().getType("key_value").asGroupType().getType("value").asGroupType(); } return type.asGroupType(); @@ -295,13 +299,21 @@ final void writeField(Object value) { recordConsumer.startField("list", 0); // This is the wrapper group for the array field for (Object listEntry: list) { recordConsumer.startGroup(); - if (isPrimitive(listEntry)) { - recordConsumer.startField("element", 0); + + recordConsumer.startField("element", 0); // This is the mandatory inner field + + if (!isPrimitive(listEntry)) { + recordConsumer.startGroup(); } + fieldWriter.writeRawValue(listEntry); - if (isPrimitive(listEntry)) { - recordConsumer.endField("element", 0); + + if (!isPrimitive(listEntry)) { + recordConsumer.endGroup(); } + + recordConsumer.endField("element", 0); + recordConsumer.endGroup(); } recordConsumer.endField("list", 0); @@ -369,15 +381,21 @@ public MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) { final void writeRawValue(Object value) { recordConsumer.startGroup(); - recordConsumer.startField("map", 0); // This is the wrapper group for the map field - for(MapEntry entry : (Collection>) value) { + recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field + for (Message msg : (Collection) value) { recordConsumer.startGroup(); - keyWriter.writeField(entry.getKey()); - valueWriter.writeField(entry.getValue()); + + final Descriptor descriptorForType = msg.getDescriptorForType(); + final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key"); + final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value"); + + keyWriter.writeField(msg.getField(keyDesc)); + valueWriter.writeField(msg.getField(valueDesc)); + recordConsumer.endGroup(); } - recordConsumer.endField("map", 0); + recordConsumer.endField("key_value", 0); recordConsumer.endGroup(); } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 34f2f23f2f..d7ec169ce7 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -103,8 +103,8 @@ public void testProto3ConvertAllDatatypes() throws Exception { " optional binary optionalEnum (ENUM) = 18;" + " optional int32 someInt32 = 19;" + " optional binary someString (UTF8) = 20;" + - " repeated group optionalMap (MAP) = 21 {\n" + - " repeated group map (MAP_KEY_VALUE) {\n" + + " optional group optionalMap (MAP) = 21 {\n" + + " repeated group key_value {\n" + " required int64 key;\n" + " optional group value {\n" + " optional int32 someId = 3;\n" + @@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }" + "}"; @@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception { " }\n" + " required group repeatedMessage (LIST) = 9 {\n" + " repeated group list {\n" + - " optional int32 someId = 3;\n" + + " optional group element {\n" + + " optional int32 someId = 3;\n" + + " }\n" + " }\n" + " }\n" + "}"; diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index e00facfc06..de27ebf3f4 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -169,6 +169,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startField("inner", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); + + inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); @@ -177,6 +180,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("inner", 0); @@ -201,12 +207,18 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("list", 0); inOrder.verify(readConsumerMock).startGroup(); + + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); + inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -235,17 +247,25 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup(); @@ -274,17 +294,25 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { //first inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("one", 0); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes())); inOrder.verify(readConsumerMock).endField("one", 0); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); //second inner message inOrder.verify(readConsumerMock).startGroup(); + inOrder.verify(readConsumerMock).startField("element", 0); + inOrder.verify(readConsumerMock).startGroup(); inOrder.verify(readConsumerMock).startField("two", 1); inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes())); inOrder.verify(readConsumerMock).endField("two", 1); inOrder.verify(readConsumerMock).endGroup(); + inOrder.verify(readConsumerMock).endField("element", 0); + inOrder.verify(readConsumerMock).endGroup(); inOrder.verify(readConsumerMock).endField("list", 0); inOrder.verify(readConsumerMock).endGroup();