diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java index 21a526e7fa3a..3f455c6e9e9b 100644 --- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java +++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetSchemaConverter.java @@ -30,6 +30,7 @@ import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; import java.util.HashMap; @@ -53,6 +54,7 @@ import static java.util.Objects.requireNonNull; import static org.apache.parquet.schema.LogicalTypeAnnotation.decimalType; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; +import static org.apache.parquet.schema.Type.Repetition.REQUIRED; public class ParquetSchemaConverter { @@ -86,120 +88,120 @@ private MessageType convert(List types, List columnNames) { Types.MessageTypeBuilder builder = Types.buildMessage(); for (int i = 0; i < types.size(); i++) { - builder.addField(convert(types.get(i), columnNames.get(i), ImmutableList.of())); + builder.addField(convert(types.get(i), columnNames.get(i), ImmutableList.of(), OPTIONAL)); } return builder.named("trino_schema"); } - private org.apache.parquet.schema.Type convert(Type type, String name, List parent) + private org.apache.parquet.schema.Type convert(Type type, String name, List parent, Repetition repetition) { if (ROW.equals(type.getTypeSignature().getBase())) { - return getRowType((RowType) type, name, parent); + return getRowType((RowType) type, name, parent, repetition); } else if (MAP.equals(type.getTypeSignature().getBase())) { - return getMapType((MapType) type, name, parent); + return getMapType((MapType) type, name, parent, repetition); } else if (ARRAY.equals(type.getTypeSignature().getBase())) { - return getArrayType((ArrayType) type, name, parent); + return getArrayType((ArrayType) type, name, parent, repetition); } else { - return getPrimitiveType(type, name, parent); + return getPrimitiveType(type, name, parent, repetition); } } - private org.apache.parquet.schema.Type getPrimitiveType(Type type, String name, List parent) + private org.apache.parquet.schema.Type getPrimitiveType(Type type, String name, List parent, Repetition repetition) { List fullName = ImmutableList.builder().addAll(parent).add(name).build(); primitiveTypes.put(fullName, type); if (BOOLEAN.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.BOOLEAN, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.BOOLEAN, repetition).named(name); } if (INTEGER.equals(type) || SMALLINT.equals(type) || TINYINT.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition).named(name); } if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; // Apache Hive version 3 or lower does not support reading decimals encoded as INT32/INT64 if (!useLegacyDecimalEncoding) { if (decimalType.getPrecision() <= 9) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32) + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition) .as(decimalType(decimalType.getScale(), decimalType.getPrecision())) .named(name); } if (decimalType.isShort()) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT64) + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition) .as(decimalType(decimalType.getScale(), decimalType.getPrecision())) .named(name); } } - return Types.optional(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + return Types.primitive(PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, repetition) .length(PRECISION_TO_BYTE_COUNT[decimalType.getPrecision()]) .as(decimalType(decimalType.getScale(), decimalType.getPrecision())) .named(name); } if (DATE.equals(type)) { - return Types.optional(PrimitiveType.PrimitiveTypeName.INT32).as(OriginalType.DATE).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT32, repetition).as(OriginalType.DATE).named(name); } if (BIGINT.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition).named(name); } if (type instanceof TimestampType) { TimestampType timestampType = (TimestampType) type; if (timestampType.getPrecision() <= 3) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, OPTIONAL).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MILLIS)).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MILLIS)).named(name); } if (timestampType.getPrecision() <= 6) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, OPTIONAL).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.MICROS)).named(name); } if (timestampType.getPrecision() <= 9) { // Per https://github.com/apache/parquet-format/blob/master/LogicalTypes.md, nanosecond precision timestamp should be stored as INT64 // even though it can only hold values within 1677-09-21 00:12:43 and 2262-04-11 23:47:16 range. - return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, OPTIONAL).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.INT64, repetition).as(LogicalTypeAnnotation.timestampType(false, LogicalTypeAnnotation.TimeUnit.NANOS)).named(name); } } if (DOUBLE.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.DOUBLE, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.DOUBLE, repetition).named(name); } if (RealType.REAL.equals(type)) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.FLOAT, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.FLOAT, repetition).named(name); } if (type instanceof VarcharType || type instanceof CharType) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, OPTIONAL).as(LogicalTypeAnnotation.stringType()).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, repetition).as(LogicalTypeAnnotation.stringType()).named(name); } if (type instanceof VarbinaryType) { - return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, OPTIONAL).named(name); + return Types.primitive(PrimitiveType.PrimitiveTypeName.BINARY, repetition).named(name); } throw new TrinoException(NOT_SUPPORTED, format("Unsupported primitive type: %s", type)); } - private org.apache.parquet.schema.Type getArrayType(ArrayType type, String name, List parent) + private org.apache.parquet.schema.Type getArrayType(ArrayType type, String name, List parent, Repetition repetition) { Type elementType = type.getElementType(); - return Types.list(OPTIONAL) - .element(convert(elementType, "array", ImmutableList.builder().addAll(parent).add(name).add("list").build())) + return Types.list(repetition) + .element(convert(elementType, "array", ImmutableList.builder().addAll(parent).add(name).add("list").build(), OPTIONAL)) .named(name); } - private org.apache.parquet.schema.Type getMapType(MapType type, String name, List parent) + private org.apache.parquet.schema.Type getMapType(MapType type, String name, List parent, Repetition repetition) { parent = ImmutableList.builder().addAll(parent).add(name).add("key_value").build(); Type keyType = type.getKeyType(); Type valueType = type.getValueType(); - return Types.map(OPTIONAL) - .key(convert(keyType, "key", parent)) - .value(convert(valueType, "value", parent)) + return Types.map(repetition) + .key(convert(keyType, "key", parent, REQUIRED)) + .value(convert(valueType, "value", parent, OPTIONAL)) .named(name); } - private org.apache.parquet.schema.Type getRowType(RowType type, String name, List parent) + private org.apache.parquet.schema.Type getRowType(RowType type, String name, List parent, Repetition repetition) { parent = ImmutableList.builder().addAll(parent).add(name).build(); - Types.GroupBuilder builder = Types.buildGroup(OPTIONAL); + Types.GroupBuilder builder = Types.buildGroup(repetition); for (RowType.Field field : type.getFields()) { checkArgument(field.getName().isPresent(), "field in struct type doesn't have name"); - builder.addField(convert(field.getType(), field.getName().get(), parent)); + builder.addField(convert(field.getType(), field.getName().get(), parent, OPTIONAL)); } return builder.named(name); } diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetSchemaConverter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetSchemaConverter.java index 2f77d7c9e6f6..082da0f8a231 100644 --- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetSchemaConverter.java +++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetSchemaConverter.java @@ -14,15 +14,26 @@ package io.trino.parquet.writer; import com.google.common.collect.ImmutableList; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; import org.testng.annotations.Test; import java.math.BigInteger; import static io.trino.parquet.writer.ParquetSchemaConverter.HIVE_PARQUET_USE_LEGACY_DECIMAL_ENCODING; +import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.Decimals.MAX_PRECISION; import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RowType.field; +import static io.trino.spi.type.RowType.rowType; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.testing.StructuralTestUtil.mapType; +import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; +import static org.apache.parquet.schema.Type.Repetition.REPEATED; +import static org.apache.parquet.schema.Type.Repetition.REQUIRED; import static org.assertj.core.api.Assertions.assertThat; public class TestParquetSchemaConverter @@ -70,4 +81,34 @@ public void testDecimalTypeLengthWithLegacyEncoding() assertThat(bigInteger.toByteArray().length).isEqualTo(primitiveType.getTypeLength()); } } + + @Test + public void testMapKeyRepetitionLevel() + { + ParquetSchemaConverter schemaConverter = new ParquetSchemaConverter( + ImmutableList.of(mapType(VARCHAR, INTEGER)), + ImmutableList.of("test"), + false); + GroupType mapType = schemaConverter.getMessageType().getType(0).asGroupType(); + GroupType keyValueValue = mapType.getType(0).asGroupType(); + assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); + Type keyType = keyValueValue.getType(0).asPrimitiveType(); + assertThat(keyType.isRepetition(REQUIRED)).isTrue(); + PrimitiveType valueType = keyValueValue.getType(1).asPrimitiveType(); + assertThat(valueType.isRepetition(OPTIONAL)).isTrue(); + + schemaConverter = new ParquetSchemaConverter( + ImmutableList.of(mapType(rowType(field("a", VARCHAR), field("b", BIGINT)), INTEGER)), + ImmutableList.of("test"), + false); + mapType = schemaConverter.getMessageType().getType(0).asGroupType(); + keyValueValue = mapType.getType(0).asGroupType(); + assertThat(keyValueValue.isRepetition(REPEATED)).isTrue(); + keyType = keyValueValue.getType(0).asGroupType(); + assertThat(keyType.isRepetition(REQUIRED)).isTrue(); + assertThat(keyType.asGroupType().getType(0).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); + assertThat(keyType.asGroupType().getType(1).asPrimitiveType().isRepetition(OPTIONAL)).isTrue(); + valueType = keyValueValue.getType(1).asPrimitiveType(); + assertThat(valueType.isRepetition(OPTIONAL)).isTrue(); + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java index a605e243deaf..af6f30b55377 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/AbstractTestParquetReader.java @@ -459,7 +459,7 @@ public void testSingleLevelArrayOfMapOfArray() } @Test - public void testMapOfArray() + public void testMapOfArrayValues() throws Exception { Iterable> arrays = createNullableTestArrays(limit(cycle(asList(1, null, 3, 5, null, null, null, 7, 11, null, 13, 17)), 30_000)); @@ -471,6 +471,22 @@ public void testMapOfArray() values, values, mapType(INTEGER, new ArrayType(INTEGER))); } + @Test + public void testMapOfArrayKeys() + throws Exception + { + Iterable> mapKeys = createTestArrays(limit(cycle(asList(1, null, 3, 5, null, null, null, 7, 11, null, 13, 17)), 30_000)); + Iterable mapValues = intsBetween(0, 30_000); + Iterable, Integer>> testMaps = createTestMaps(mapKeys, mapValues); + tester.testRoundTrip( + getStandardMapObjectInspector( + getStandardListObjectInspector(javaIntObjectInspector), + javaIntObjectInspector), + testMaps, + testMaps, + mapType(new ArrayType(INTEGER), INTEGER)); + } + @Test public void testMapOfSingleLevelArray() throws Exception