diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java index c409071bd332d..8afaded8b96e3 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java @@ -96,6 +96,7 @@ import static com.facebook.presto.tests.StructuralTestUtil.mapType; import static com.google.common.base.Functions.compose; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.cycle; import static com.google.common.collect.Iterables.limit; @@ -181,6 +182,105 @@ public void testNestedArrays() tester.testRoundTrip(objectInspector, values, values, type); } + @Test + public void testNestedArraysDecimalBackedByINT32() + throws Exception + { + int precision = 1; + int scale = 0; + ObjectInspector objectInspector = getStandardListObjectInspector(javaIntObjectInspector); + Type type = new ArrayType(createDecimalType(precision, scale)); + Iterable> values = createTestArrays(intsBetween(1, 1_000)); + + ImmutableList.Builder> expectedValues = new ImmutableList.Builder<>(); + for (List value : values) { + expectedValues.add(value.stream() + .map(valueInt -> SqlDecimal.of(valueInt, precision, scale)) + .collect(toImmutableList())); + } + + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + + " optional group my_list (LIST){" + + " repeated group list {" + + " optional INT32 element (DECIMAL(%d, %d));" + + " }" + + " }" + + "} ", precision, scale)); + + tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema)); + } + + @Test + public void testNestedArraysDecimalBackedByINT64() + throws Exception + { + int precision = 10; + int scale = 2; + ObjectInspector objectInspector = getStandardListObjectInspector(javaLongObjectInspector); + Type type = new ArrayType(createDecimalType(precision, scale)); + Iterable> values = createTestArrays(longsBetween(1, 1_000)); + + ImmutableList.Builder> expectedValues = new ImmutableList.Builder<>(); + for (List value : values) { + expectedValues.add(value.stream() + .map(valueLong -> SqlDecimal.of(valueLong, precision, scale)) + .collect(toImmutableList())); + } + + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + + " optional group my_list (LIST){" + + " repeated group list {" + + " optional INT64 element (DECIMAL(%d, %d));" + + " }" + + " }" + + "} ", precision, scale)); + tester.testRoundTrip(objectInspector, values, expectedValues.build(), "my_list", type, Optional.of(hiveSchema)); + } + + @Test + public void testNestedArraysShortDecimalBackedByBinary() + throws Exception + { + int precision = 1; + int scale = 0; + ObjectInspector objectInspector = getStandardListObjectInspector(new JavaHiveDecimalObjectInspector(new DecimalTypeInfo(precision, scale))); + Type type = new ArrayType(createDecimalType(precision, scale)); + Iterable> values = getNestedDecimalArrayInputValues(precision, scale); + List> expectedValues = getNestedDecimalArrayExpectedValues(values, precision, scale); + + MessageType hiveSchema = parseMessageType(format("message hive_list_decimal {" + + " optional group my_list (LIST){" + + " repeated group list {" + + " optional BINARY element (DECIMAL(%d, %d));" + + " }" + + " }" + + "} ", precision, scale)); + + tester.testRoundTrip(objectInspector, values, expectedValues, "my_list", type, Optional.of(hiveSchema)); + } + + private Iterable> getNestedDecimalArrayInputValues(int precision, int scale) + { + ContiguousSet bigIntegerValues = bigIntegersBetween(BigDecimal.valueOf(Math.pow(10, precision - 1)).toBigInteger(), + BigDecimal.valueOf(Math.pow(10, precision)).toBigInteger()); + List writeValues = bigIntegerValues.stream() + .map(value -> HiveDecimal.create((BigInteger) value, scale)) + .collect(toImmutableList()); + + return createTestArrays(writeValues); + } + + private static List> getNestedDecimalArrayExpectedValues(Iterable> values, int precision, int scale) + { + ImmutableList.Builder> expectedValues = new ImmutableList.Builder<>(); + for (List value : values) { + expectedValues.add(value.stream() + .map(valueHiveDecimal -> new SqlDecimal(valueHiveDecimal.unscaledValue(), precision, scale)) + .collect(toImmutableList())); + } + return expectedValues.build(); + } + @Test public void testSingleLevelSchemaNestedArrays() throws Exception diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/ColumnReaderFactory.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/ColumnReaderFactory.java index 485935d579fc9..e03efd94d5882 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/ColumnReaderFactory.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/ColumnReaderFactory.java @@ -63,8 +63,8 @@ private ColumnReaderFactory() public static ColumnReader createReader(RichColumnDescriptor descriptor, boolean batchReadEnabled) { - if (batchReadEnabled) { - final boolean isNested = descriptor.getPath().length > 1; + final boolean isNested = descriptor.getPath().length > 1; + if (batchReadEnabled && (!(isNested && isDecimalType(descriptor)))) { switch (descriptor.getPrimitiveType().getPrimitiveTypeName()) { case BOOLEAN: return isNested ? new BooleanNestedBatchReader(descriptor) : new BooleanFlatBatchReader(descriptor);