Skip to content

Commit

Permalink
Fix Proto2Parquet conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
andredasilvapinto committed Sep 5, 2017
1 parent 121c0b7 commit dfa9701
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -374,21 +374,18 @@ public void addBinary(Binary binary) {
*/
final class ListConverter extends GroupConverter {
private final Converter converter;
private final boolean listOfMessage;

public ListConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) {
OriginalType originalType = parquetType.getOriginalType();
if (originalType != OriginalType.LIST) {
throw new ParquetDecodingException("Expected LIST wrapper. Found: " + originalType + " instead.");
}

listOfMessage = fieldDescriptor.getJavaType() == JavaType.MESSAGE;

Type parquetSchema;
if (parquetType.asGroupType().containsField("list")) {
parquetSchema = parquetType.asGroupType().getType("list");
if (parquetSchema.asGroupType().containsField("element")) {
parquetSchema.asGroupType().getType("element");
parquetSchema = parquetSchema.asGroupType().getType("element");
}
} else {
throw new ParquetDecodingException("Expected list but got: " + parquetType);
Expand All @@ -403,10 +400,6 @@ public Converter getConverter(int fieldIndex) {
throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper");
}

if (listOfMessage) {
return converter;
}

return new GroupConverter() {
@Override
public Converter getConverter(int fieldIndex) {
Expand Down Expand Up @@ -447,10 +440,10 @@ public MapConverter(Message.Builder parentBuilder, Descriptors.FieldDescriptor f
}

Type parquetSchema;
if (parquetType.asGroupType().containsField("map")){
parquetSchema = parquetType.asGroupType().getType("map");
if (parquetType.asGroupType().containsField("key_value")){
parquetSchema = parquetType.asGroupType().getType("key_value");
} else {
throw new ParquetDecodingException("Expected map but got: " + parquetType);
throw new ParquetDecodingException("Expected key_value but got: " + parquetType);
}

converter = newMessageConverter(parentBuilder, fieldDescriptor, parquetSchema);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -19,13 +19,15 @@
package org.apache.parquet.proto;

import com.google.protobuf.Descriptors;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType;
import com.google.protobuf.Message;
import com.twitter.elephantbird.util.Protobufs;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Type.Repetition;
import org.apache.parquet.schema.Types;
import org.apache.parquet.schema.Types.Builder;
import org.apache.parquet.schema.Types.GroupBuilder;
Expand Down Expand Up @@ -59,7 +61,7 @@ public MessageType convert(Class<? extends Message> protobufClass) {
}

/* Iterates over list of fields. **/
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<Descriptors.FieldDescriptor> fieldDescriptors) {
private <T> GroupBuilder<T> convertFields(GroupBuilder<T> groupBuilder, List<FieldDescriptor> fieldDescriptors) {
for (Descriptors.FieldDescriptor fieldDescriptor : fieldDescriptors) {
groupBuilder =
addField(fieldDescriptor, groupBuilder)
Expand Down Expand Up @@ -105,14 +107,15 @@ private <T> Builder<? extends Builder<?, GroupBuilder<T>>, GroupBuilder<T>> addR
}

private <T> GroupBuilder<GroupBuilder<T>> addRepeatedMessage(Descriptors.FieldDescriptor descriptor, GroupBuilder<T> builder) {
GroupBuilder<GroupBuilder<GroupBuilder<T>>> result =
GroupBuilder<GroupBuilder<GroupBuilder<GroupBuilder<T>>>> result =
builder
.group(Type.Repetition.REQUIRED).as(OriginalType.LIST)
.group(Type.Repetition.REPEATED);
.group(Type.Repetition.REPEATED)
.group(Type.Repetition.OPTIONAL);

convertFields(result, descriptor.getMessageType().getFields());

return result.named("list");
return result.named("element").named("list");
}

private <T> GroupBuilder<GroupBuilder<T>> addMessageField(Descriptors.FieldDescriptor descriptor, final GroupBuilder<T> builder) {
Expand All @@ -137,12 +140,12 @@ private <T> GroupBuilder<GroupBuilder<T>> addMapField(Descriptors.FieldDescripto
ParquetType mapKeyParquetType = getParquetType(fields.get(0));

GroupBuilder<GroupBuilder<GroupBuilder<T>>> group = builder
.group(getRepetition(descriptor)).as(OriginalType.MAP)
.group(Type.Repetition.REPEATED).as(OriginalType.MAP_KEY_VALUE)
.group(Repetition.OPTIONAL).as(OriginalType.MAP) // only optional maps are allowed in Proto3
.group(Type.Repetition.REPEATED)
.primitive(mapKeyParquetType.primitiveType, Type.Repetition.REQUIRED).as(mapKeyParquetType.originalType).named("key");

return addField(fields.get(1), group).named("value")
.named("map");
.named("key_value");
}

private ParquetType getParquetType(Descriptors.FieldDescriptor fieldDescriptor) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -21,7 +21,11 @@
import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.TextFormat;
Expand Down Expand Up @@ -203,11 +207,11 @@ private FieldWriter createMessageWriter(Descriptors.FieldDescriptor fieldDescrip

private GroupType getGroupType(Type type) {
if (type.getOriginalType() == OriginalType.LIST) {
return type.asGroupType().getType("list").asGroupType();
return type.asGroupType().getType("list").asGroupType().getType("element").asGroupType();
}

if (type.getOriginalType() == OriginalType.MAP) {
return type.asGroupType().getType("map").asGroupType().getType("value").asGroupType();
return type.asGroupType().getType("key_value").asGroupType().getType("value").asGroupType();
}

return type.asGroupType();
Expand Down Expand Up @@ -295,13 +299,21 @@ final void writeField(Object value) {
recordConsumer.startField("list", 0); // This is the wrapper group for the array field
for (Object listEntry: list) {
recordConsumer.startGroup();
if (isPrimitive(listEntry)) {
recordConsumer.startField("element", 0);

recordConsumer.startField("element", 0); // This is the mandatory inner field

if (!isPrimitive(listEntry)) {
recordConsumer.startGroup();
}

fieldWriter.writeRawValue(listEntry);
if (isPrimitive(listEntry)) {
recordConsumer.endField("element", 0);

if (!isPrimitive(listEntry)) {
recordConsumer.endGroup();
}

recordConsumer.endField("element", 0);

recordConsumer.endGroup();
}
recordConsumer.endField("list", 0);
Expand Down Expand Up @@ -369,15 +381,21 @@ public MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) {
final void writeRawValue(Object value) {
recordConsumer.startGroup();

recordConsumer.startField("map", 0); // This is the wrapper group for the map field
for(MapEntry<?, ?> entry : (Collection<MapEntry<?, ?>>) value) {
recordConsumer.startField("key_value", 0); // This is the wrapper group for the map field
for (Message msg : (Collection<Message>) value) {
recordConsumer.startGroup();
keyWriter.writeField(entry.getKey());
valueWriter.writeField(entry.getValue());

final Descriptor descriptorForType = msg.getDescriptorForType();
final FieldDescriptor keyDesc = descriptorForType.findFieldByName("key");
final FieldDescriptor valueDesc = descriptorForType.findFieldByName("value");

keyWriter.writeField(msg.getField(keyDesc));
valueWriter.writeField(msg.getField(valueDesc));

recordConsumer.endGroup();
}

recordConsumer.endField("map", 0);
recordConsumer.endField("key_value", 0);

recordConsumer.endGroup();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ public void testProto3ConvertAllDatatypes() throws Exception {
" optional binary optionalEnum (ENUM) = 18;" +
" optional int32 someInt32 = 19;" +
" optional binary someString (UTF8) = 20;" +
" repeated group optionalMap (MAP) = 21 {\n" +
" repeated group map (MAP_KEY_VALUE) {\n" +
" optional group optionalMap (MAP) = 21 {\n" +
" repeated group key_value {\n" +
" required int64 key;\n" +
" optional group value {\n" +
" optional int32 someId = 3;\n" +
Expand Down Expand Up @@ -135,7 +135,9 @@ public void testConvertRepetition() throws Exception {
" }\n" +
" required group repeatedMessage (LIST) = 9 {\n" +
" repeated group list {\n" +
" optional int32 someId = 3;\n" +
" optional group element {\n" +
" optional int32 someId = 3;\n" +
" }\n" +
" }\n" +
" }" +
"}";
Expand All @@ -158,7 +160,9 @@ public void testProto3ConvertRepetition() throws Exception {
" }\n" +
" required group repeatedMessage (LIST) = 9 {\n" +
" repeated group list {\n" +
" optional int32 someId = 3;\n" +
" optional group element {\n" +
" optional int32 someId = 3;\n" +
" }\n" +
" }\n" +
" }\n" +
"}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception {
inOrder.verify(readConsumerMock).startField("inner", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("list", 0);

inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("one", 0);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes()));
Expand All @@ -177,6 +180,9 @@ public void testRepeatedInnerMessageMessage_message() throws Exception {
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes()));
inOrder.verify(readConsumerMock).endField("two", 1);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);
inOrder.verify(readConsumerMock).endGroup();

inOrder.verify(readConsumerMock).endField("list", 0);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("inner", 0);
Expand All @@ -201,12 +207,18 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception {
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("list", 0);
inOrder.verify(readConsumerMock).startGroup();

inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("one", 0);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes()));
inOrder.verify(readConsumerMock).endField("one", 0);
inOrder.verify(readConsumerMock).startField("two", 1);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes()));
inOrder.verify(readConsumerMock).endField("two", 1);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);

inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("list", 0);
inOrder.verify(readConsumerMock).endGroup();
Expand Down Expand Up @@ -235,17 +247,25 @@ public void testRepeatedInnerMessageMessage_scalar() throws Exception {

//first inner message
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("one", 0);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes()));
inOrder.verify(readConsumerMock).endField("one", 0);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);
inOrder.verify(readConsumerMock).endGroup();

//second inner message
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("two", 1);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes()));
inOrder.verify(readConsumerMock).endField("two", 1);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);
inOrder.verify(readConsumerMock).endGroup();

inOrder.verify(readConsumerMock).endField("list", 0);
inOrder.verify(readConsumerMock).endGroup();
Expand Down Expand Up @@ -274,17 +294,25 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception {

//first inner message
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("one", 0);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("one".getBytes()));
inOrder.verify(readConsumerMock).endField("one", 0);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);
inOrder.verify(readConsumerMock).endGroup();

//second inner message
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("element", 0);
inOrder.verify(readConsumerMock).startGroup();
inOrder.verify(readConsumerMock).startField("two", 1);
inOrder.verify(readConsumerMock).addBinary(Binary.fromConstantByteArray("two".getBytes()));
inOrder.verify(readConsumerMock).endField("two", 1);
inOrder.verify(readConsumerMock).endGroup();
inOrder.verify(readConsumerMock).endField("element", 0);
inOrder.verify(readConsumerMock).endGroup();

inOrder.verify(readConsumerMock).endField("list", 0);
inOrder.verify(readConsumerMock).endGroup();
Expand Down

0 comments on commit dfa9701

Please sign in to comment.