From ff583e7174b1d2c568c92ad9cd74eb202a7da2ca Mon Sep 17 00:00:00 2001 From: v-jizhang Date: Fri, 17 Jun 2022 10:41:37 -0700 Subject: [PATCH 1/2] Refactor ParquetSchemaConverter to parameterize Repetition for types Cherry-pick of https://github.com/trinodb/trino/pull/12808/commits/f4aa94df20886ca2787f6c090cfa4869ad0aa595 Co-authored-by: Raunaq Morarka --- .../facebook/presto/common/type/RowType.java | 6 ++ .../writer/ParquetSchemaConverter.java | 55 +++++++++-------- .../writer/TestParquetSchemaConverter.java | 61 +++++++++++++++++++ 3 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java b/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java index 0740f0ebf6b6b..9e2595f917118 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java @@ -21,6 +21,7 @@ import com.facebook.presto.common.function.SqlFunctionProperties; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -64,6 +65,11 @@ public static RowType anonymous(List types) return new RowType(fields); } + public static RowType rowType(Field... 
field) + { + return from(Arrays.asList(field)); + } + public static RowType withDefaultFieldNames(List types) { List fields = new ArrayList<>(); diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java index af684972bf111..50e0690403bb5 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java @@ -28,6 +28,7 @@ import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; import java.util.HashMap; @@ -68,53 +69,53 @@ private MessageType convert(List types, List columnNames) { Types.MessageTypeBuilder builder = Types.buildMessage(); for (int i = 0; i < types.size(); i++) { - builder.addField(convert(types.get(i), columnNames.get(i), ImmutableList.of())); + builder.addField(convert(types.get(i), columnNames.get(i), ImmutableList.of(), OPTIONAL)); } return builder.named("presto_schema"); } - private org.apache.parquet.schema.Type convert(Type type, String name, List parent) + private org.apache.parquet.schema.Type convert(Type type, String name, List parent, Repetition repetition) { if (ROW.equals(type.getTypeSignature().getBase())) { - return getRowType((RowType) type, name, parent); + return getRowType((RowType) type, name, parent, repetition); } else if (MAP.equals(type.getTypeSignature().getBase())) { - return getMapType((MapType) type, name, parent); + return getMapType((MapType) type, name, parent, repetition); } else if (ARRAY.equals(type.getTypeSignature().getBase())) { - return getArrayType((ArrayType) type, name, parent); + return getArrayType((ArrayType) type, name, parent, repetition); } else { - return 
getPrimitiveType(type, name, parent); + return getPrimitiveType(type, name, parent, repetition); } } - private org.apache.parquet.schema.Type getPrimitiveType(Type type, String name, List parent) + private org.apache.parquet.schema.Type getPrimitiveType(Type type, String name, List parent, Repetition repetition) { List fullName = ImmutableList.builder().addAll(parent).add(name).build(); primitiveTypes.put(fullName, type); if (BOOLEAN.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.BOOLEAN, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.BOOLEAN, repetition).named(name); } if (INTEGER.equals(type) || SMALLINT.equals(type) || TINYINT.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition).named(name); } if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; if (decimalType.getPrecision() <= 9) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition) .as(OriginalType.DECIMAL) .precision(decimalType.getPrecision()) .scale(decimalType.getScale()).named(name); } else if (decimalType.isShort()) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition) .as(OriginalType.DECIMAL) .precision(decimalType.getPrecision()) .scale(decimalType.getScale()).named(name); } else { - return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + return Types.primitive(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, repetition) .length(16) .as(OriginalType.DECIMAL) .precision(decimalType.getPrecision()) @@ -122,49 +123,49 @@ else if (decimalType.isShort()) { } } if (DATE.equals(type)) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32).as(OriginalType.DATE).named(name); + return 
Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition).as(OriginalType.DATE).named(name); } if (BIGINT.equals(type) || TIMESTAMP.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition).named(name); } if (DOUBLE.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.DOUBLE, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.DOUBLE, repetition).named(name); } if (RealType.REAL.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.FLOAT, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.FLOAT, repetition).named(name); } if (type instanceof VarcharType || type instanceof CharType || type instanceof VarbinaryType) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, repetition).named(name); } throw new PrestoException(NOT_SUPPORTED, format("Unsupported primitive type: %s", type)); } - private org.apache.parquet.schema.Type getArrayType(ArrayType type, String name, List parent) + private org.apache.parquet.schema.Type getArrayType(ArrayType type, String name, List parent, Repetition repetition) { Type elementType = type.getElementType(); - return Types.list(OPTIONAL) - .element(convert(elementType, "array", ImmutableList.builder().addAll(parent).add(name).add("list").build())) + return Types.list(repetition) + .element(convert(elementType, "array", ImmutableList.builder().addAll(parent).add(name).add("list").build(), OPTIONAL)) .named(name); } - private org.apache.parquet.schema.Type getMapType(MapType type, String name, List parent) + private org.apache.parquet.schema.Type getMapType(MapType type, String name, List parent, Repetition repetition) { parent = ImmutableList.builder().addAll(parent).add(name).add("key_value").build(); Type keyType 
= type.getKeyType(); Type valueType = type.getValueType(); - return Types.map(OPTIONAL) - .key(convert(keyType, "key", parent)) - .value(convert(valueType, "value", parent)) + return Types.map(repetition) + .key(convert(keyType, "key", parent, OPTIONAL)) + .value(convert(valueType, "value", parent, OPTIONAL)) .named(name); } - private org.apache.parquet.schema.Type getRowType(RowType type, String name, List parent) + private org.apache.parquet.schema.Type getRowType(RowType type, String name, List parent, Repetition repetition) { parent = ImmutableList.builder().addAll(parent).add(name).build(); - Types.GroupBuilder builder = Types.buildGroup(OPTIONAL); + Types.GroupBuilder builder = Types.buildGroup(repetition); for (RowType.Field field : type.getFields()) { com.google.common.base.Preconditions.checkArgument(field.getName().isPresent(), "field in struct type doesn't have name"); - builder.addField(convert(field.getType(), field.getName().get(), parent)); + builder.addField(convert(field.getType(), field.getName().get(), parent, OPTIONAL)); } return builder.named(name); } diff --git a/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java b/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java new file mode 100644 index 0000000000000..fb81995046c55 --- /dev/null +++ b/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.parquet.writer; + +import com.google.common.collect.ImmutableList; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.RowType.field; +import static com.facebook.presto.common.type.RowType.rowType; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.tests.StructuralTestUtil.mapType; +import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; +import static org.apache.parquet.schema.Type.Repetition.REPEATED; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestParquetSchemaConverter +{ + @Test + public void testMapKeyRepetitionLevel() + { + ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter( + ImmutableList.of(mapType(VARCHAR, INTEGER)), + ImmutableList.of("test")); + GroupType mapType = schemaConverter.getMessageType().getType(0).asGroupType(); + GroupType keyValueValue = mapType.getType(0).asGroupType(); + assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); + Type keyType = keyValueValue.getType(0).asPrimitiveType(); + assertThat(keyType.isRepetition(OPTIONAL)).isTrue(); + PrimitiveType valueType = keyValueValue.getType(1).asPrimitiveType(); + assertThat(valueType.isRepetition(OPTIONAL)).isTrue(); + + schemaConverter = new ParquetSchemaConverter( + ImmutableList.of(mapType(rowType(field("a", VARCHAR), field("b", BIGINT)), INTEGER)), + ImmutableList.of("test")); + mapType = schemaConverter.getMessageType().getType(0).asGroupType(); + keyValueValue = mapType.getType(0).asGroupType(); + 
+ assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); + keyType = keyValueValue.getType(0).asGroupType(); + assertThat(keyType.isRepetition(OPTIONAL)).isTrue(); + assertThat(keyType.asGroupType().getType(0).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); + assertThat(keyType.asGroupType().getType(1).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); + valueType = keyValueValue.getType(1).asPrimitiveType(); + assertThat(valueType.isRepetition(OPTIONAL)).isTrue(); + } +} From cc6d55e13d24b84b52bd63f6bba2a2e5542df73c Mon Sep 17 00:00:00 2001 From: v-jizhang Date: Fri, 17 Jun 2022 11:06:16 -0700 Subject: [PATCH 2/2] Use REQUIRED repetition level for MAP keys in parquet writer Cherry-pick of https://github.com/trinodb/trino/pull/12808/commits/5155e8653bbc664fe9c2d051d2acd90ea8b7e2fa As per the Parquet format specification, the key field of a MAP must be REQUIRED. Co-authored-by: Raunaq Morarka --- .../parquet/AbstractTestParquetReader.java | 18 +++++++++++++++++- .../parquet/writer/ParquetSchemaConverter.java | 3 ++- .../writer/TestParquetSchemaConverter.java | 5 +++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java index 5fbf0d1446685..eeeae62204a01 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java @@ -475,7 +475,7 @@ public void testSingleLevelArrayOfMapOfArray() } @Test - public void testMapOfArray() + public void testMapOfArrayValues() throws Exception { Iterable> arrays = createNullableTestArrays(limit(cycle(asList(1, null, 3, 5, null, null, null, 7, 11, null, 13, 17)), 30_000)); @@ -487,6 +487,22 @@ public void testMapOfArray() values, values, mapType(INTEGER, new ArrayType(INTEGER))); } + @Test + public void testMapOfArrayKeys() + throws
Exception + { + Iterable> mapKeys = createTestArrays(limit(cycle(asList(1, null, 3, 5, null, null, null, 7, 11, null, 13, 17)), 30_000)); + Iterable mapValues = intsBetween(0, 30_000); + Iterable, Integer>> testMaps = createTestMaps(mapKeys, mapValues); + tester.testRoundTrip( + getStandardMapObjectInspector( + getStandardListObjectInspector(javaIntObjectInspector), + javaIntObjectInspector), + testMaps, + testMaps, + mapType(new ArrayType(INTEGER), INTEGER)); + } + @Test public void testMapOfSingleLevelArray() throws Exception diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java index 50e0690403bb5..22a5396120723 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/writer/ParquetSchemaConverter.java @@ -51,6 +51,7 @@ import static java.util.Objects.requireNonNull; import static org.apache.parquet.Preconditions.checkArgument; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; +import static org.apache.parquet.schema.Type.Repetition.REQUIRED; public class ParquetSchemaConverter { @@ -154,7 +155,7 @@ private org.apache.parquet.schema.Type getMapType(MapType type, String name, Lis Type keyType = type.getKeyType(); Type valueType = type.getValueType(); return Types.map(repetition) - .key(convert(keyType, "key", parent, OPTIONAL)) + .key(convert(keyType, "key", parent, REQUIRED)) .value(convert(valueType, "value", parent, OPTIONAL)) .named(name); } diff --git a/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java b/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java index fb81995046c55..f42524c050dbd 100644 --- a/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java +++ 
b/presto-parquet/src/test/java/com/facebook/presto/parquet/writer/TestParquetSchemaConverter.java @@ -27,6 +27,7 @@ import static com.facebook.presto.tests.StructuralTestUtil.mapType; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; +import static org.apache.parquet.schema.Type.Repetition.REQUIRED; import static org.assertj.core.api.Assertions.assertThat; public class TestParquetSchemaConverter @@ -41,7 +42,7 @@ public void testMapKeyRepetitionLevel() GroupType keyValueValue = mapType.getType(0).asGroupType(); assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); Type keyType = keyValueValue.getType(0).asPrimitiveType(); - assertThat(keyType.isRepetition(OPTIONAL)).isTrue(); + assertThat(keyType.isRepetition(REQUIRED)).isTrue(); PrimitiveType valueType = keyValueValue.getType(1).asPrimitiveType(); assertThat(valueType.isRepetition(OPTIONAL)).isTrue(); @@ -52,7 +53,7 @@ public void testMapKeyRepetitionLevel() keyValueValue = mapType.getType(0).asGroupType(); assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); keyType = keyValueValue.getType(0).asGroupType(); - assertThat(keyType.isRepetition(OPTIONAL)).isTrue(); + assertThat(keyType.isRepetition(REQUIRED)).isTrue(); assertThat(keyType.asGroupType().getType(0).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); assertThat(keyType.asGroupType().getType(1).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); valueType = keyValueValue.getType(1).asPrimitiveType();