diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
index 632f19190452..2cf5aba0dd48 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
@@ -4560,6 +4560,26 @@ public void testNestedColumnWithDuplicateName()
         assertUpdate("DROP TABLE " + tableName);
     }
 
+    @Test
+    public void testParquetNaNStatistics()
+    {
+        String tableName = "test_parquet_nan_statistics";
+
+        assertUpdate("CREATE TABLE " + tableName + " (c_double DOUBLE, c_real REAL, c_string VARCHAR) WITH (format = 'PARQUET')");
+        assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), cast(nan() as REAL), 'all nan')", 1);
+        assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), null, 'null real'), (null, nan(), 'null double')", 2);
+        assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), 4.2, '4.2 real'), (4.2, nan(), '4.2 double')", 2);
+        assertUpdate("INSERT INTO " + tableName + " VALUES (0.1, 0.1, 'both 0.1')", 1);
+
+        // These assertions are intended to make sure we are handling NaN values in Parquet statistics,
+        // however Parquet file stats created in Presto don't include such values; the test is here mainly to prevent
+        // regressions, should a new writer start recording such stats
+        assertQuery("SELECT c_string FROM " + tableName + " WHERE c_double > 4", "VALUES ('4.2 double')");
+        assertQuery("SELECT c_string FROM " + tableName + " WHERE c_real > 4", "VALUES ('4.2 real')");
+
+        assertUpdate("DROP TABLE " + tableName);
+    }
+
     @Test
     public void testMismatchedBucketing()
     {
diff --git a/presto-parquet/src/main/java/io/prestosql/parquet/predicate/TupleDomainParquetPredicate.java b/presto-parquet/src/main/java/io/prestosql/parquet/predicate/TupleDomainParquetPredicate.java
index 5c8747ae4f89..b59a884f71d8 100644
--- a/presto-parquet/src/main/java/io/prestosql/parquet/predicate/TupleDomainParquetPredicate.java
+++ b/presto-parquet/src/main/java/io/prestosql/parquet/predicate/TupleDomainParquetPredicate.java
@@ -191,11 +191,14 @@ public static Domain getDomain(Type type, long rowCount, Statistics statistic
                 failWithCorruptionException(failOnCorruptedParquetStatistics, column, id, floatStatistics);
                 return Domain.create(ValueSet.all(type), hasNullValue);
             }
+            // NaN bounds cannot be turned into a value range; fall back to an unconstrained domain
+            if (floatStatistics.genericGetMin().isNaN() || floatStatistics.genericGetMax().isNaN()) {
+                return Domain.create(ValueSet.all(type), hasNullValue);
+            }
 
             ParquetIntegerStatistics parquetStatistics = new ParquetIntegerStatistics(
                     (long) floatToRawIntBits(floatStatistics.getMin()),
                     (long) floatToRawIntBits(floatStatistics.getMax()));
-
             return createDomain(type, hasNullValue, parquetStatistics);
         }
 
@@ -205,6 +208,11 @@ public static Domain getDomain(Type type, long rowCount, Statistics statistic
                 failWithCorruptionException(failOnCorruptedParquetStatistics, column, id, doubleStatistics);
                 return Domain.create(ValueSet.all(type), hasNullValue);
             }
+            // NaN bounds cannot be turned into a value range; fall back to an unconstrained domain
+            if (doubleStatistics.genericGetMin().isNaN() || doubleStatistics.genericGetMax().isNaN()) {
+                return Domain.create(ValueSet.all(type), hasNullValue);
+            }
+
             ParquetDoubleStatistics parquetDoubleStatistics = new ParquetDoubleStatistics(doubleStatistics.genericGetMin(), doubleStatistics.genericGetMax());
             return createDomain(type, hasNullValue, parquetDoubleStatistics);
         }
@@ -316,7 +324,12 @@ public static Domain getDomain(Type type, DictionaryDescriptor dictionaryDescrip
         if (type.equals(DOUBLE) && columnDescriptor.getPrimitiveType().getPrimitiveTypeName() == PrimitiveTypeName.DOUBLE) {
             List<Domain> domains = new ArrayList<>();
             for (int i = 0; i < dictionarySize; i++) {
-                domains.add(Domain.singleValue(type, dictionary.decodeToDouble(i)));
+                double value = dictionary.decodeToDouble(i);
+                // NaN entry: give up on a dictionary-derived domain and leave the column unconstrained
+                if (Double.isNaN(value)) {
+                    return Domain.all(type);
+                }
+                domains.add(Domain.singleValue(type, value));
             }
             domains.add(Domain.onlyNull(type));
             return Domain.union(domains);
@@ -325,7 +338,12 @@ public static Domain getDomain(Type type, DictionaryDescriptor dictionaryDescrip
         if (type.equals(DOUBLE) && columnDescriptor.getPrimitiveType().getPrimitiveTypeName() == PrimitiveTypeName.FLOAT) {
             List<Domain> domains = new ArrayList<>();
             for (int i = 0; i < dictionarySize; i++) {
-                domains.add(Domain.singleValue(type, (double) dictionary.decodeToFloat(i)));
+                float value = dictionary.decodeToFloat(i);
+                // NaN entry: give up on a dictionary-derived domain and leave the column unconstrained
+                if (Float.isNaN(value)) {
+                    return Domain.all(type);
+                }
+                domains.add(Domain.singleValue(type, (double) value));
             }
             domains.add(Domain.onlyNull(type));
             return Domain.union(domains);
diff --git a/presto-parquet/src/test/java/io/prestosql/parquet/TestTupleDomainParquetPredicate.java b/presto-parquet/src/test/java/io/prestosql/parquet/TestTupleDomainParquetPredicate.java
index da05ea822405..f1adf88fed56 100644
--- a/presto-parquet/src/test/java/io/prestosql/parquet/TestTupleDomainParquetPredicate.java
+++ b/presto-parquet/src/test/java/io/prestosql/parquet/TestTupleDomainParquetPredicate.java
@@ -25,6 +25,7 @@
 import io.prestosql.spi.type.TimestampType;
 import io.prestosql.spi.type.Type;
 import io.prestosql.spi.type.VarcharType;
+import org.apache.parquet.bytes.LittleEndianDataOutputStream;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.statistics.BinaryStatistics;
 import org.apache.parquet.column.statistics.BooleanStatistics;
@@ -38,6 +39,7 @@
 import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
+import java.io.ByteArrayOutputStream;
 import java.math.BigDecimal;
 import java.time.Instant;
 import java.time.LocalDate;
@@ -69,6 +71,7 @@
 import static io.prestosql.spi.type.TinyintType.TINYINT;
 import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType;
 import static io.prestosql.spi.type.VarcharType.createVarcharType;
+import static java.lang.Float.NaN;
 import static java.lang.Float.floatToRawIntBits;
 import static java.lang.Math.toIntExact;
 import static java.nio.charset.StandardCharsets.UTF_8;
@@ -237,7 +240,7 @@ public void testLongDecimal()
 
     @Test
     public void testDouble()
-            throws ParquetCorruptionException
+            throws Exception
     {
         String column = "DoubleColumn";
         assertEquals(getDomain(DOUBLE, 0, null, ID, column, true), all(DOUBLE));
@@ -246,6 +249,18 @@ public void testDouble()
 
         assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, 42.24), ID, column, true), create(ValueSet.ofRanges(range(DOUBLE, 3.3, true, 42.24, true)), false));
 
+        assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(NaN, NaN), ID, column, true), Domain.notNull(DOUBLE));
+
+        assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(NaN, NaN, true), ID, column, true), Domain.all(DOUBLE));
+
+        assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, NaN), ID, column, true), Domain.notNull(DOUBLE));
+
+        assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, NaN, true), ID, column, true), Domain.all(DOUBLE));
+
+        assertEquals(getDomain(DOUBLE, doubleDictionaryDescriptor(NaN)), Domain.all(DOUBLE));
+
+        assertEquals(getDomain(DOUBLE, doubleDictionaryDescriptor(3.3, NaN)), Domain.all(DOUBLE));
+
         // ignore corrupted statistics
         assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(42.24, 3.3), ID, column, false), create(ValueSet.all(DOUBLE), false));
         // fail on corrupted statistics
@@ -254,13 +269,6 @@ public void testDouble()
                 .withMessage("Corrupted statistics for column \"DoubleColumn\" in Parquet file \"testFile\": [min: 42.24, max: 3.3, num_nulls: 0]");
     }
 
-    private static DoubleStatistics doubleColumnStats(double minimum, double maximum)
-    {
-        DoubleStatistics statistics = new DoubleStatistics();
-        statistics.setMinMax(minimum, maximum);
-        return statistics;
-    }
-
     @Test
     public void testString()
             throws ParquetCorruptionException
@@ -291,7 +299,7 @@ private static BinaryStatistics stringColumnStats(String minimum, String maximum
 
     @Test
     public void testFloat()
-            throws ParquetCorruptionException
+            throws Exception
     {
         String column = "FloatColumn";
         assertEquals(getDomain(REAL, 0, null, ID, column, true), all(REAL));
@@ -305,6 +313,18 @@ public void testFloat()
                 getDomain(REAL, 10, floatColumnStats(minimum, maximum), ID, column, true),
                 create(ValueSet.ofRanges(range(REAL, (long) floatToRawIntBits(minimum), true, (long) floatToRawIntBits(maximum), true)), false));
 
+        assertEquals(getDomain(REAL, 10, floatColumnStats(NaN, NaN), ID, column, true), Domain.notNull(REAL));
+
+        assertEquals(getDomain(REAL, 10, floatColumnStats(NaN, NaN, true), ID, column, true), Domain.all(REAL));
+
+        assertEquals(getDomain(REAL, 10, floatColumnStats(minimum, NaN), ID, column, true), Domain.notNull(REAL));
+
+        assertEquals(getDomain(REAL, 10, floatColumnStats(minimum, NaN, true), ID, column, true), Domain.all(REAL));
+
+        assertEquals(getDomain(REAL, floatDictionaryDescriptor(NaN)), Domain.all(REAL));
+
+        assertEquals(getDomain(REAL, floatDictionaryDescriptor(minimum, NaN)), Domain.all(REAL));
+
         // ignore corrupted statistics
         assertEquals(getDomain(REAL, 10, floatColumnStats(maximum, minimum), ID, column, false), create(ValueSet.all(REAL), false));
         // fail on corrupted statistics
@@ -453,9 +473,32 @@ private TupleDomain getEffectivePredicate(RichColumnDescriptor
     }
 
     private static FloatStatistics floatColumnStats(float minimum, float maximum)
+    {
+        return floatColumnStats(minimum, maximum, false);
+    }
+
+    private static FloatStatistics floatColumnStats(float minimum, float maximum, boolean hasNulls)
     {
         FloatStatistics statistics = new FloatStatistics();
         statistics.setMinMax(minimum, maximum);
+        if (hasNulls) {
+            statistics.setNumNulls(1);
+        }
+        return statistics;
+    }
+
+    private static DoubleStatistics doubleColumnStats(double minimum, double maximum)
+    {
+        return doubleColumnStats(minimum, maximum, false);
+    }
+
+    private static DoubleStatistics doubleColumnStats(double minimum, double maximum, boolean hasNulls)
+    {
+        DoubleStatistics statistics = new DoubleStatistics();
+        statistics.setMinMax(minimum, maximum);
+        if (hasNulls) {
+            statistics.setNumNulls(1);
+        }
         return statistics;
     }
 
@@ -479,4 +522,34 @@ private static LongStatistics longOnlyNullsStats(long numNulls)
         statistics.setNumNulls(numNulls);
         return statistics;
     }
+
+    private static DictionaryDescriptor floatDictionaryDescriptor(float... values)
+            throws Exception
+    {
+        // Dictionary page payload: the values written back-to-back in little-endian order
+        ByteArrayOutputStream buf = new ByteArrayOutputStream();
+        try (LittleEndianDataOutputStream out = new LittleEndianDataOutputStream(buf)) {
+            for (float val : values) {
+                out.writeFloat(val);
+            }
+        }
+        return new DictionaryDescriptor(
+                new ColumnDescriptor(new String[] {"dummy"}, new PrimitiveType(OPTIONAL, PrimitiveType.PrimitiveTypeName.FLOAT, 0, ""), 1, 1),
+                Optional.of(new DictionaryPage(Slices.wrappedBuffer(buf.toByteArray()), values.length, PLAIN_DICTIONARY)));
+    }
+
+    private static DictionaryDescriptor doubleDictionaryDescriptor(double... values)
+            throws Exception
+    {
+        // Dictionary page payload: the values written back-to-back in little-endian order
+        ByteArrayOutputStream buf = new ByteArrayOutputStream();
+        try (LittleEndianDataOutputStream out = new LittleEndianDataOutputStream(buf)) {
+            for (double val : values) {
+                out.writeDouble(val);
+            }
+        }
+        return new DictionaryDescriptor(
+                new ColumnDescriptor(new String[] {"dummy"}, new PrimitiveType(OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, 0, ""), 1, 1),
+                Optional.of(new DictionaryPage(Slices.wrappedBuffer(buf.toByteArray()), values.length, PLAIN_DICTIONARY)));
+    }
 }