Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4560,6 +4560,24 @@ public void testNestedColumnWithDuplicateName()
assertUpdate("DROP TABLE " + tableName);
}

/**
 * Checks that range predicates on DOUBLE and REAL columns of a Parquet table return the
 * expected rows when the column data also contains NaN and NULL values.
 */
@Test
public void testParquetNaNStatistics()
{
    String tableName = "test_parquet_nan_statistics";

    assertUpdate("CREATE TABLE " + tableName + " (c_double DOUBLE, c_real REAL, c_string VARCHAR) WITH (format = 'PARQUET')");
    // Each INSERT is a separate statement, mixing NaN, NULL and ordinary values in both columns
    assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), cast(nan() as REAL), 'all nan')", 1);
    assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), null, 'null real'), (null, nan(), 'null double')", 2);
    assertUpdate("INSERT INTO " + tableName + " VALUES (nan(), 4.2, '4.2 real'), (4.2, nan(), '4.2 double')", 2);
    assertUpdate("INSERT INTO " + tableName + " VALUES (0.1, 0.1, 'both 0.1')", 1);

    // These assertions are intended to make sure we handle NaN values in Parquet statistics correctly.
    // Parquet file stats created by Presto do not include such values, so the test is here mainly to
    // prevent regressions, should a new writer start recording such stats.
    assertQuery("SELECT c_string FROM " + tableName + " WHERE c_double > 4", "VALUES ('4.2 double')");
    assertQuery("SELECT c_string FROM " + tableName + " WHERE c_real > 4", "VALUES ('4.2 real')");

    // Drop the table so it does not leak into other tests or make a re-run fail on CREATE TABLE
    assertUpdate("DROP TABLE " + tableName);
}

@Test
public void testMismatchedBucketing()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,13 @@ public static Domain getDomain(Type type, long rowCount, Statistics<?> statistic
failWithCorruptionException(failOnCorruptedParquetStatistics, column, id, floatStatistics);
return Domain.create(ValueSet.all(type), hasNullValue);
}
if (floatStatistics.genericGetMin().isNaN() || floatStatistics.genericGetMax().isNaN()) {
return Domain.create(ValueSet.all(type), hasNullValue);
}

ParquetIntegerStatistics parquetStatistics = new ParquetIntegerStatistics(
(long) floatToRawIntBits(floatStatistics.getMin()),
(long) floatToRawIntBits(floatStatistics.getMax()));

return createDomain(type, hasNullValue, parquetStatistics);
}

Expand All @@ -205,6 +207,10 @@ public static Domain getDomain(Type type, long rowCount, Statistics<?> statistic
failWithCorruptionException(failOnCorruptedParquetStatistics, column, id, doubleStatistics);
return Domain.create(ValueSet.all(type), hasNullValue);
}
if (doubleStatistics.genericGetMin().isNaN() || doubleStatistics.genericGetMax().isNaN()) {
return Domain.create(ValueSet.all(type), hasNullValue);
}

ParquetDoubleStatistics parquetDoubleStatistics = new ParquetDoubleStatistics(doubleStatistics.genericGetMin(), doubleStatistics.genericGetMax());
return createDomain(type, hasNullValue, parquetDoubleStatistics);
}
Expand Down Expand Up @@ -316,7 +322,11 @@ public static Domain getDomain(Type type, DictionaryDescriptor dictionaryDescrip
if (type.equals(DOUBLE) && columnDescriptor.getPrimitiveType().getPrimitiveTypeName() == PrimitiveTypeName.DOUBLE) {
List<Domain> domains = new ArrayList<>();
for (int i = 0; i < dictionarySize; i++) {
domains.add(Domain.singleValue(type, dictionary.decodeToDouble(i)));
double value = dictionary.decodeToDouble(i);
if (Double.isNaN(value)) {
return Domain.all(type);
}
domains.add(Domain.singleValue(type, value));
}
domains.add(Domain.onlyNull(type));
return Domain.union(domains);
Expand All @@ -325,7 +335,11 @@ public static Domain getDomain(Type type, DictionaryDescriptor dictionaryDescrip
if (type.equals(DOUBLE) && columnDescriptor.getPrimitiveType().getPrimitiveTypeName() == PrimitiveTypeName.FLOAT) {
List<Domain> domains = new ArrayList<>();
for (int i = 0; i < dictionarySize; i++) {
domains.add(Domain.singleValue(type, (double) dictionary.decodeToFloat(i)));
float value = dictionary.decodeToFloat(i);
if (Float.isNaN(value)) {
return Domain.all(type);
}
domains.add(Domain.singleValue(type, (double) value));
}
domains.add(Domain.onlyNull(type));
return Domain.union(domains);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import org.apache.parquet.bytes.LittleEndianDataOutputStream;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.statistics.BinaryStatistics;
import org.apache.parquet.column.statistics.BooleanStatistics;
Expand All @@ -38,6 +39,7 @@
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.ByteArrayOutputStream;
import java.math.BigDecimal;
import java.time.Instant;
import java.time.LocalDate;
Expand Down Expand Up @@ -69,6 +71,7 @@
import static io.prestosql.spi.type.TinyintType.TINYINT;
import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType;
import static io.prestosql.spi.type.VarcharType.createVarcharType;
import static java.lang.Float.NaN;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Math.toIntExact;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -237,7 +240,7 @@ public void testLongDecimal()

@Test
public void testDouble()
throws ParquetCorruptionException
throws Exception
{
String column = "DoubleColumn";
assertEquals(getDomain(DOUBLE, 0, null, ID, column, true), all(DOUBLE));
Expand All @@ -246,6 +249,18 @@ public void testDouble()

assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, 42.24), ID, column, true), create(ValueSet.ofRanges(range(DOUBLE, 3.3, true, 42.24, true)), false));

assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(NaN, NaN), ID, column, true), Domain.notNull(DOUBLE));

assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(NaN, NaN, true), ID, column, true), Domain.all(DOUBLE));

assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, NaN), ID, column, true), Domain.notNull(DOUBLE));

assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(3.3, NaN, true), ID, column, true), Domain.all(DOUBLE));

assertEquals(getDomain(DOUBLE, doubleDictionaryDescriptor(NaN)), Domain.all(DOUBLE));

assertEquals(getDomain(DOUBLE, doubleDictionaryDescriptor(3.3, NaN)), Domain.all(DOUBLE));

// ignore corrupted statistics
assertEquals(getDomain(DOUBLE, 10, doubleColumnStats(42.24, 3.3), ID, column, false), create(ValueSet.all(DOUBLE), false));
// fail on corrupted statistics
Expand All @@ -254,13 +269,6 @@ public void testDouble()
.withMessage("Corrupted statistics for column \"DoubleColumn\" in Parquet file \"testFile\": [min: 42.24, max: 3.3, num_nulls: 0]");
}

/**
 * Builds {@link DoubleStatistics} carrying the given minimum and maximum.
 *
 * @param minimum value recorded as the statistics minimum
 * @param maximum value recorded as the statistics maximum
 * @return a new statistics object with min/max set
 */
private static DoubleStatistics doubleColumnStats(double minimum, double maximum)
{
    DoubleStatistics result = new DoubleStatistics();
    result.setMinMax(minimum, maximum);
    return result;
}

@Test
public void testString()
throws ParquetCorruptionException
Expand Down Expand Up @@ -291,7 +299,7 @@ private static BinaryStatistics stringColumnStats(String minimum, String maximum

@Test
public void testFloat()
throws ParquetCorruptionException
throws Exception
{
String column = "FloatColumn";
assertEquals(getDomain(REAL, 0, null, ID, column, true), all(REAL));
Expand All @@ -305,6 +313,18 @@ public void testFloat()
getDomain(REAL, 10, floatColumnStats(minimum, maximum), ID, column, true),
create(ValueSet.ofRanges(range(REAL, (long) floatToRawIntBits(minimum), true, (long) floatToRawIntBits(maximum), true)), false));

assertEquals(getDomain(REAL, 10, floatColumnStats(NaN, NaN), ID, column, true), Domain.notNull(REAL));

assertEquals(getDomain(REAL, 10, floatColumnStats(NaN, NaN, true), ID, column, true), Domain.all(REAL));

assertEquals(getDomain(REAL, 10, floatColumnStats(minimum, NaN), ID, column, true), Domain.notNull(REAL));

assertEquals(getDomain(REAL, 10, floatColumnStats(minimum, NaN, true), ID, column, true), Domain.all(REAL));

assertEquals(getDomain(REAL, floatDictionaryDescriptor(NaN)), Domain.all(REAL));

assertEquals(getDomain(REAL, floatDictionaryDescriptor(minimum, NaN)), Domain.all(REAL));

// ignore corrupted statistics
assertEquals(getDomain(REAL, 10, floatColumnStats(maximum, minimum), ID, column, false), create(ValueSet.all(REAL), false));
// fail on corrupted statistics
Expand Down Expand Up @@ -453,9 +473,32 @@ private TupleDomain<ColumnDescriptor> getEffectivePredicate(RichColumnDescriptor
}

/**
 * Builds {@link FloatStatistics} for the given bounds without setting a null count.
 *
 * @param minimum value recorded as the statistics minimum
 * @param maximum value recorded as the statistics maximum
 * @return statistics produced by the three-argument overload with {@code hasNulls = false}
 */
private static FloatStatistics floatColumnStats(float minimum, float maximum)
{
    boolean hasNulls = false;
    return floatColumnStats(minimum, maximum, hasNulls);
}

/**
 * Builds {@link FloatStatistics} for the given bounds, optionally recording one null value.
 *
 * @param minimum value recorded as the statistics minimum
 * @param maximum value recorded as the statistics maximum
 * @param hasNulls when {@code true}, the null count is set to 1; otherwise it is left untouched
 * @return a new statistics object
 */
private static FloatStatistics floatColumnStats(float minimum, float maximum, boolean hasNulls)
{
    FloatStatistics result = new FloatStatistics();
    result.setMinMax(minimum, maximum);
    if (!hasNulls) {
        return result;
    }
    // A single null is enough for the statistics to report that nulls are present
    result.setNumNulls(1);
    return result;
}

/**
 * Builds {@link DoubleStatistics} for the given bounds without setting a null count.
 *
 * @param minimum value recorded as the statistics minimum
 * @param maximum value recorded as the statistics maximum
 * @return statistics produced by the three-argument overload with {@code hasNulls = false}
 */
private static DoubleStatistics doubleColumnStats(double minimum, double maximum)
{
    boolean hasNulls = false;
    return doubleColumnStats(minimum, maximum, hasNulls);
}

/**
 * Builds {@link DoubleStatistics} for the given bounds, optionally recording one null value.
 *
 * @param minimum value recorded as the statistics minimum
 * @param maximum value recorded as the statistics maximum
 * @param hasNulls when {@code true}, the null count is set to 1; otherwise it is left untouched
 * @return a new statistics object
 */
private static DoubleStatistics doubleColumnStats(double minimum, double maximum, boolean hasNulls)
{
    DoubleStatistics result = new DoubleStatistics();
    result.setMinMax(minimum, maximum);
    if (!hasNulls) {
        return result;
    }
    // A single null is enough for the statistics to report that nulls are present
    result.setNumNulls(1);
    return result;
}

Expand All @@ -479,4 +522,32 @@ private static LongStatistics longOnlyNullsStats(long numNulls)
statistics.setNumNulls(numNulls);
return statistics;
}

/**
 * Builds a {@link DictionaryDescriptor} for a FLOAT column whose dictionary page holds the given values.
 * <p>
 * The values are written one after another through a {@link LittleEndianDataOutputStream} and the
 * resulting bytes are wrapped in a {@code PLAIN_DICTIONARY} encoded {@link DictionaryPage}.
 * The column descriptor uses the path {@code "dummy"} and an OPTIONAL FLOAT primitive type.
 *
 * @param values dictionary entries, in dictionary order
 * @return descriptor pairing the column with its dictionary page
 * @throws Exception if writing to or closing the output stream fails
 */
private static DictionaryDescriptor floatDictionaryDescriptor(float... values)
        throws Exception
{
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    // Close the stream before reading the buffer so every value is flushed into it
    try (LittleEndianDataOutputStream out = new LittleEndianDataOutputStream(buffer)) {
        for (float value : values) {
            out.writeFloat(value);
        }
    }
    return new DictionaryDescriptor(
            new ColumnDescriptor(new String[] {"dummy"}, new PrimitiveType(OPTIONAL, PrimitiveType.PrimitiveTypeName.FLOAT, 0, ""), 1, 1),
            Optional.of(new DictionaryPage(Slices.wrappedBuffer(buffer.toByteArray()), values.length, PLAIN_DICTIONARY)));
}

/**
 * Builds a {@link DictionaryDescriptor} for a DOUBLE column whose dictionary page holds the given values.
 * <p>
 * The values are written one after another through a {@link LittleEndianDataOutputStream} and the
 * resulting bytes are wrapped in a {@code PLAIN_DICTIONARY} encoded {@link DictionaryPage}.
 * The column descriptor uses the path {@code "dummy"} and an OPTIONAL DOUBLE primitive type.
 *
 * @param values dictionary entries, in dictionary order
 * @return descriptor pairing the column with its dictionary page
 * @throws Exception if writing to or closing the output stream fails
 */
private static DictionaryDescriptor doubleDictionaryDescriptor(double... values)
        throws Exception
{
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    // Close the stream before reading the buffer so every value is flushed into it
    try (LittleEndianDataOutputStream out = new LittleEndianDataOutputStream(buffer)) {
        for (double value : values) {
            out.writeDouble(value);
        }
    }
    return new DictionaryDescriptor(
            new ColumnDescriptor(new String[] {"dummy"}, new PrimitiveType(OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, 0, ""), 1, 1),
            Optional.of(new DictionaryPage(Slices.wrappedBuffer(buffer.toByteArray()), values.length, PLAIN_DICTIONARY)));
}
}