Hive statistics: ignore non-finite floating point min/max values

Statistics.getDoubleValue now returns OptionalDouble.empty() when the computed
value is NaN or +/-Infinity, so such values are not recorded as column min/max.
createHiveColumnStatistics becomes package-private (@VisibleForTesting) so that
TestStatistics can call it directly. Unit tests for REAL and DOUBLE and a
product test (testComputeFloatingPointStatistics) cover the new behaviour.

diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/Statistics.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/Statistics.java
index c5af7731205e..bbb03d736c5a 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/Statistics.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/Statistics.java
@@ -13,6 +13,7 @@
  */
 package io.prestosql.plugin.hive.util;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableMap;
 import io.prestosql.plugin.hive.HiveBasicStatistics;
 import io.prestosql.plugin.hive.PartitionStatistics;
@@ -353,7 +354,8 @@ private static Map<String, Map<ColumnStatisticType, Block>> createColumnToComput
                 .collect(toImmutableMap(Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));
     }
 
-    private static HiveColumnStatistics createHiveColumnStatistics(
+    @VisibleForTesting
+    static HiveColumnStatistics createHiveColumnStatistics(
             ConnectorSession session,
             DateTimeZone timeZone,
             Map<ColumnStatisticType, Block> computedStatistics,
@@ -436,7 +438,14 @@ private static OptionalLong getIntegerValue(ConnectorSession session, Type type,
 
     private static OptionalDouble getDoubleValue(ConnectorSession session, Type type, Block block)
     {
-        return block.isNull(0) ? OptionalDouble.empty() : OptionalDouble.of(((Number) type.getObjectValue(session, block, 0)).doubleValue());
+        if (block.isNull(0)) {
+            return OptionalDouble.empty();
+        }
+        double value = ((Number) type.getObjectValue(session, block, 0)).doubleValue();
+        if (!Double.isFinite(value)) {
+            return OptionalDouble.empty();
+        }
+        return OptionalDouble.of(value);
     }
 
     private static Optional<LocalDate> getDateValue(ConnectorSession session, Type type, Block block)
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/util/TestStatistics.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/util/TestStatistics.java
index 05267cd742b3..48659e156066 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/util/TestStatistics.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/util/TestStatistics.java
@@ -21,6 +21,8 @@
 import io.prestosql.plugin.hive.metastore.DoubleStatistics;
 import io.prestosql.plugin.hive.metastore.HiveColumnStatistics;
 import io.prestosql.plugin.hive.metastore.IntegerStatistics;
+import io.prestosql.spi.block.Block;
+import io.prestosql.spi.statistics.ColumnStatisticType;
 import org.testng.annotations.Test;
 
 import java.math.BigDecimal;
@@ -32,18 +34,88 @@
 
 import static io.prestosql.plugin.hive.HiveBasicStatistics.createEmptyStatistics;
 import static io.prestosql.plugin.hive.HiveBasicStatistics.createZeroStatistics;
+import static io.prestosql.plugin.hive.HiveTestUtils.SESSION;
 import static io.prestosql.plugin.hive.metastore.HiveColumnStatistics.createBinaryColumnStatistics;
 import static io.prestosql.plugin.hive.metastore.HiveColumnStatistics.createBooleanColumnStatistics;
-import static io.prestosql.plugin.hive.metastore.HiveColumnStatistics.createDoubleColumnStatistics;
 import static io.prestosql.plugin.hive.metastore.HiveColumnStatistics.createIntegerColumnStatistics;
 import static io.prestosql.plugin.hive.util.Statistics.ReduceOperator.ADD;
 import static io.prestosql.plugin.hive.util.Statistics.ReduceOperator.SUBTRACT;
+import static io.prestosql.plugin.hive.util.Statistics.createHiveColumnStatistics;
 import static io.prestosql.plugin.hive.util.Statistics.merge;
 import static io.prestosql.plugin.hive.util.Statistics.reduce;
+import static io.prestosql.spi.predicate.Utils.nativeValueToBlock;
+import static io.prestosql.spi.statistics.ColumnStatisticType.MAX_VALUE;
+import static io.prestosql.spi.statistics.ColumnStatisticType.MIN_VALUE;
+import static io.prestosql.spi.type.DoubleType.DOUBLE;
+import static io.prestosql.spi.type.RealType.REAL;
+import static java.lang.Float.floatToIntBits;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.joda.time.DateTimeZone.UTC;
 
 public class TestStatistics
 {
+    @Test
+    public void testCreateRealHiveColumnStatistics()
+    {
+        HiveColumnStatistics statistics;
+
+        statistics = createRealColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(-2391f)),
+                MAX_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(42f))));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.of(-2391d), OptionalDouble.of(42)));
+
+        statistics = createRealColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(Float.NEGATIVE_INFINITY)),
+                MAX_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(Float.POSITIVE_INFINITY))));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.empty(), OptionalDouble.empty()));
+
+        statistics = createRealColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(Float.NaN)),
+                MAX_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(Float.NaN))));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.empty(), OptionalDouble.empty()));
+
+        statistics = createRealColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(-15f)),
+                MAX_VALUE, nativeValueToBlock(REAL, (long) floatToIntBits(-0f))));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.of(-15d), OptionalDouble.of(-0d))); // TODO should we distinguish between -0 and 0?
+    }
+
+    private static HiveColumnStatistics createRealColumnStatistics(ImmutableMap<ColumnStatisticType, Block> computedStatistics)
+    {
+        return createHiveColumnStatistics(SESSION, UTC, computedStatistics, REAL, 1);
+    }
+
+    @Test
+    public void testCreateDoubleHiveColumnStatistics()
+    {
+        HiveColumnStatistics statistics;
+
+        statistics = createDoubleColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(DOUBLE, -2391d),
+                MAX_VALUE, nativeValueToBlock(DOUBLE, 42d)));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.of(-2391d), OptionalDouble.of(42)));
+
+        statistics = createDoubleColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(DOUBLE, Double.NEGATIVE_INFINITY),
+                MAX_VALUE, nativeValueToBlock(DOUBLE, Double.POSITIVE_INFINITY)));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.empty(), OptionalDouble.empty()));
+
+        statistics = createDoubleColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(DOUBLE, Double.NaN),
+                MAX_VALUE, nativeValueToBlock(DOUBLE, Double.NaN)));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.empty(), OptionalDouble.empty()));
+
+        statistics = createDoubleColumnStatistics(ImmutableMap.of(
+                MIN_VALUE, nativeValueToBlock(DOUBLE, -15d),
+                MAX_VALUE, nativeValueToBlock(DOUBLE, -0d)));
+        assertThat(statistics.getDoubleStatistics().get()).isEqualTo(new DoubleStatistics(OptionalDouble.of(-15d), OptionalDouble.of(-0d))); // TODO should we distinguish between -0 and 0?
+    }
+
+    private static HiveColumnStatistics createDoubleColumnStatistics(ImmutableMap<ColumnStatisticType, Block> computedStatistics)
+    {
+        return createHiveColumnStatistics(SESSION, UTC, computedStatistics, DOUBLE, 1);
+    }
+
     @Test
     public void testReduce()
     {
@@ -219,16 +291,16 @@ public void testMergeHiveColumnStatisticsMap()
     {
         Map<String, HiveColumnStatistics> first = ImmutableMap.of(
                 "column1", createIntegerColumnStatistics(OptionalLong.of(1), OptionalLong.of(2), OptionalLong.of(3), OptionalLong.of(4)),
-                "column2", createDoubleColumnStatistics(OptionalDouble.of(2), OptionalDouble.of(3), OptionalLong.of(4), OptionalLong.of(5)),
+                "column2", HiveColumnStatistics.createDoubleColumnStatistics(OptionalDouble.of(2), OptionalDouble.of(3), OptionalLong.of(4), OptionalLong.of(5)),
                 "column3", createBinaryColumnStatistics(OptionalLong.of(5), OptionalLong.of(5), OptionalLong.of(10)),
                 "column4", createBooleanColumnStatistics(OptionalLong.of(1), OptionalLong.of(2), OptionalLong.of(3)));
         Map<String, HiveColumnStatistics> second = ImmutableMap.of(
                 "column5", createIntegerColumnStatistics(OptionalLong.of(1), OptionalLong.of(2), OptionalLong.of(3), OptionalLong.of(4)),
-                "column2", createDoubleColumnStatistics(OptionalDouble.of(1), OptionalDouble.of(4), OptionalLong.of(4), OptionalLong.of(6)),
+                "column2", HiveColumnStatistics.createDoubleColumnStatistics(OptionalDouble.of(1), OptionalDouble.of(4), OptionalLong.of(4), OptionalLong.of(6)),
                 "column3", createBinaryColumnStatistics(OptionalLong.of(6), OptionalLong.of(5), OptionalLong.of(10)),
                 "column6", createBooleanColumnStatistics(OptionalLong.of(1), OptionalLong.of(2), OptionalLong.of(3)));
         Map<String, HiveColumnStatistics> expected = ImmutableMap.of(
-                "column2", createDoubleColumnStatistics(OptionalDouble.of(1), OptionalDouble.of(4), OptionalLong.of(8), OptionalLong.of(6)),
+                "column2", HiveColumnStatistics.createDoubleColumnStatistics(OptionalDouble.of(1), OptionalDouble.of(4), OptionalLong.of(8), OptionalLong.of(6)),
                 "column3", createBinaryColumnStatistics(OptionalLong.of(6), OptionalLong.of(10), OptionalLong.of(20)));
         assertThat(merge(first, second)).isEqualTo(expected);
         assertThat(merge(ImmutableMap.of(), ImmutableMap.of())).isEqualTo(ImmutableMap.of());
diff --git a/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java b/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java
index f0e1165bed19..58c25db48d9e 100644
--- a/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java
+++ b/presto-product-tests/src/main/java/io/prestosql/tests/hive/TestHiveTableStatistics.java
@@ -23,9 +23,11 @@
 import io.prestosql.tempto.fulfillment.table.MutableTableRequirement;
 import io.prestosql.tempto.fulfillment.table.hive.HiveTableDefinition;
 import io.prestosql.tempto.fulfillment.table.hive.InlineDataSource;
+import org.testng.annotations.DataProvider;
 import org.testng.annotations.Test;
 
 import java.util.List;
+import java.util.Objects;
 
 import static io.prestosql.tempto.assertions.QueryAssert.Row.row;
 import static io.prestosql.tempto.assertions.QueryAssert.anyOf;
@@ -1363,6 +1365,53 @@ public void testComputePartitionStatisticsOnInsert()
         }
     }
 
+    @Test(dataProvider = "testComputeFloatingPointStatisticsDataProvider")
+    public void testComputeFloatingPointStatistics(String dataType)
+    {
+        String tableName = "test_compute_floating_point_statistics";
+        query("DROP TABLE IF EXISTS " + tableName);
+        try {
+            query(format("CREATE TABLE %1$s(c_basic %2$s, c_minmax %2$s, c_inf %2$s, c_ninf %2$s, c_nan %2$s, c_nzero %2$s)", tableName, dataType));
+            query("ANALYZE " + tableName); // TODO remove after https://github.com/prestosql/presto/issues/2469
+
+            query(format(
+                    "INSERT INTO %1$s(c_basic, c_minmax, c_inf, c_ninf, c_nan, c_nzero) VALUES " +
+                            " (%2$s '42.3', %2$s '576234.567', %2$s 'Infinity', %2$s '-Infinity', %2$s 'NaN', %2$s '-0')," +
+                            " (%2$s '42.3', %2$s '-1234567.89', %2$s '-15', %2$s '45', %2$s '12345', %2$s '-47'), " +
+                            " (NULL, NULL, NULL, NULL, NULL, NULL)",
+                    tableName,
+                    dataType));
+
+            List<Row> expectedStatistics = ImmutableList.of(
+                    row("c_basic", null, 1., 0.33333333333, null, "42.3", "42.3"),
+                    Objects.equals(dataType, "double")
+                            ? row("c_minmax", null, 2., 0.33333333333, null, "-1234567.89", "576234.567")
+                            : row("c_minmax", null, 2., 0.33333333333, null, "-1234567.9", "576234.56"),
+                    row("c_inf", null, 2., 0.33333333333, null, null, null), // -15, +inf
+                    row("c_ninf", null, 2., 0.33333333333, null, null, null), // -inf, 45
+                    row("c_nan", null, 2., 0.33333333333, null, null, null), // 12345., NaN
+                    row("c_nzero", null, 2., 0.33333333333, null, "-47.0", "0.0"),
+                    row(null, null, null, null, 3., null, null));
+
+            assertThat(query("SHOW STATS FOR " + tableName)).containsOnly(expectedStatistics);
+
+            query("ANALYZE " + tableName);
+            assertThat(query("SHOW STATS FOR " + tableName)).containsOnly(expectedStatistics);
+        }
+        finally {
+            query("DROP TABLE IF EXISTS " + tableName);
+        }
+    }
+
+    @DataProvider
+    public Object[][] testComputeFloatingPointStatisticsDataProvider()
+    {
+        return new Object[][] {
+                {"real"},
+                {"double"},
+        };
+    }
+
     private static void assertComputeTableStatisticsOnCreateTable(String sourceTableName, List<Row> expectedStatistics)
     {
         String copiedTableName = "assert_compute_table_statistics_on_create_table_" + sourceTableName;