diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java index b3e54223859e..4febe2e8edf7 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveMetadata.java @@ -243,6 +243,7 @@ import static io.prestosql.spi.predicate.TupleDomain.withColumnDomains; import static io.prestosql.spi.statistics.TableStatisticType.ROW_COUNT; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -2217,8 +2218,9 @@ private static Domain buildColumnDomain(ColumnHandle column, List checkArgument(!partitions.isEmpty(), "partitions cannot be empty"); boolean hasNull = false; + boolean hasNaN = false; List nonNullValues = new ArrayList<>(); - Type type = null; + Type type = ((HiveColumnHandle) column).getType(); for (HivePartition partition : partitions) { NullableValue value = partition.getKeys().get(column); @@ -2230,24 +2232,29 @@ private static Domain buildColumnDomain(ColumnHandle column, List hasNull = true; } else { + if (isFloatingPointNaN(type, value.getValue())) { + hasNaN = true; + } nonNullValues.add(value.getValue()); } - - if (type == null) { - type = value.getType(); - } } - if (!nonNullValues.isEmpty()) { - Domain domain = Domain.multipleValues(type, nonNullValues); - if (hasNull) { - return domain.union(Domain.onlyNull(type)); - } + Domain domain; + if (nonNullValues.isEmpty()) { + domain = Domain.none(type); + } + else if (hasNaN) { + domain = Domain.notNull(type); + } + else { + domain = Domain.multipleValues(type, nonNullValues); + } - return domain; + if (hasNull) { + domain = domain.union(Domain.onlyNull(type)); } - return Domain.onlyNull(type); + return domain; } @Override diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java index 8c1115ba6160..f36de3af59a1 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java @@ -336,6 +336,29 @@ public void testJoinPartitionedWithMissingPartitionFilter() assertUpdate(session, "DROP TABLE partition_test8"); } + @Test + public void testNaNPartition() + { + assertUpdate("DROP TABLE IF EXISTS test_nan_partition"); + assertUpdate("CREATE TABLE test_nan_partition(a varchar, d double) WITH (partitioned_by = ARRAY['d'])"); + assertUpdate("INSERT INTO test_nan_partition VALUES ('a', 42e0), ('b', nan())", 2); + + assertQuery( + "SELECT a, d, regexp_replace(\"$path\", '.*(/[^/]*/[^/]*/)[^/]*', '...$1...') FROM test_nan_partition", + "VALUES " + + " ('a', 42, '.../test_nan_partition/d=42.0/...'), " + + " ('b', SQRT(-1), '.../test_nan_partition/d=NaN/...')"); // SQRT(-1) is H2's recommended way to obtain NaN + + assertQueryReturnsEmptyResult("SELECT a FROM test_nan_partition JOIN (VALUES 33e0) u(x) ON d = x"); + assertQueryReturnsEmptyResult("SELECT a FROM test_nan_partition JOIN (VALUES 33e0) u(x) ON d = x OR rand() = 42"); + + assertQuery( + "SELECT * FROM test_nan_partition t1 JOIN test_nan_partition t2 ON t1.d = t2.d", + "VALUES ('a', 42, 'a', 42)"); + + assertUpdate("DROP TABLE test_nan_partition"); + } + @Test public void testJoinWithPartitionFilterOnPartionedTable() { @@ -3297,7 +3320,8 @@ private void testCreateExternalTable( } @Test - public void testCreateExternalTable() throws Exception + public void testCreateExternalTable() + throws Exception { testCreateExternalTable( "test_create_external", @@ -3307,7 +3331,8 @@ public void testCreateExternalTable() throws Exception } @Test - public void testCreateExternalTableWithFieldSeparator() throws Exception + public void testCreateExternalTableWithFieldSeparator() + throws Exception { testCreateExternalTable( "test_create_external", @@ -3317,7 +3342,8 @@ public void testCreateExternalTableWithFieldSeparator() throws Exception } @Test - public void testCreateExternalTableWithFieldSeparatorEscape() throws Exception + public void testCreateExternalTableWithFieldSeparatorEscape() + throws Exception { testCreateExternalTable( "test_create_external_text_file_with_field_separator_and_escape", @@ -3329,7 +3355,8 @@ public void testCreateExternalTableWithFieldSeparatorEscape() throws Exception } @Test - public void testCreateExternalTableWithNullFormat() throws Exception + public void testCreateExternalTableWithNullFormat() + throws Exception { testCreateExternalTable( "test_create_external_textfile_with_null_format", diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java index b9fea88293c1..3580c84971e1 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveMetadata.java @@ -15,16 +15,23 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.slice.Slices; import io.prestosql.spi.connector.SchemaTableName; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.NullableValue; +import io.prestosql.spi.predicate.ValueSet; import org.testng.annotations.Test; import java.util.Optional; +import java.util.stream.IntStream; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.Slices.utf8Slice; import static io.prestosql.plugin.hive.HiveColumnHandle.createBaseColumn; import static io.prestosql.plugin.hive.HiveMetadata.createPredicate; +import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; public class TestHiveMetadata { @@ -36,6 +43,14 @@ public class TestHiveMetadata HiveColumnHandle.ColumnType.PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle DOUBLE_COLUMN_HANDLE = createBaseColumn( + "test", + 0, + HiveType.HIVE_DOUBLE, + DOUBLE, + HiveColumnHandle.ColumnType.PARTITION_KEY, + Optional.empty()); + @Test(timeOut = 10_000) public void testCreatePredicate() { @@ -45,10 +60,17 @@ public void testCreatePredicate() partitions.add(new HivePartition( new SchemaTableName("test", "test"), Integer.toString(i), - ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, Slices.utf8Slice(Integer.toString(i)))))); + ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, utf8Slice(Integer.toString(i)))))); } - createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()); + Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()) + .getDomains().orElseThrow().get(TEST_COLUMN_HANDLE); + assertEquals(domain, Domain.create( + ValueSet.copyOf(VARCHAR, + IntStream.range(0, 5_000) + .mapToObj(i -> utf8Slice(Integer.toString(i))) + .collect(toImmutableList())), + false)); } @Test @@ -63,7 +85,56 @@ public void testCreateOnlyNullsPredicate() ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.asNull(VARCHAR)))); } - createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()); + Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()) + .getDomains().orElseThrow().get(TEST_COLUMN_HANDLE); + assertEquals(domain, Domain.onlyNull(VARCHAR)); + } + + @Test + public void testCreatePredicateWithNaN() + { + HiveColumnHandle columnHandle = DOUBLE_COLUMN_HANDLE; + ImmutableList.Builder partitions = ImmutableList.builder(); + + partitions.add(new HivePartition( + new SchemaTableName("test", "test"), + "p1", + ImmutableMap.of(columnHandle, NullableValue.of(DOUBLE, Double.NaN)))); + + partitions.add(new HivePartition( + new SchemaTableName("test", "test"), + "p2", + ImmutableMap.of(columnHandle, NullableValue.of(DOUBLE, 4.2)))); + + Domain domain = createPredicate(ImmutableList.of(columnHandle), partitions.build()) + .getDomains().orElseThrow().get(columnHandle); + assertEquals(domain, Domain.notNull(DOUBLE)); + } + + @Test + public void testCreatePredicateWithNaNAndNull() + { + HiveColumnHandle columnHandle = DOUBLE_COLUMN_HANDLE; + ImmutableList.Builder partitions = ImmutableList.builder(); + + partitions.add(new HivePartition( + new SchemaTableName("test", "test"), + "p1", + ImmutableMap.of(columnHandle, NullableValue.of(DOUBLE, Double.NaN)))); + + partitions.add(new HivePartition( + new SchemaTableName("test", "test"), + "p2", + ImmutableMap.of(columnHandle, NullableValue.of(DOUBLE, 4.2)))); + + partitions.add(new HivePartition( + new SchemaTableName("test", "test"), + "p3", + ImmutableMap.of(columnHandle, NullableValue.asNull(DOUBLE)))); + + Domain domain = createPredicate(ImmutableList.of(columnHandle), partitions.build()) + .getDomains().orElseThrow().get(columnHandle); + assertNull(domain); } @Test @@ -75,7 +146,7 @@ public void testCreateMixedPredicate() partitions.add(new HivePartition( new SchemaTableName("test", "test"), Integer.toString(i), - ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, Slices.utf8Slice(Integer.toString(i)))))); + ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.of(VARCHAR, utf8Slice(Integer.toString(i)))))); } partitions.add(new HivePartition( @@ -83,6 +154,8 @@ public void testCreateMixedPredicate() "null", ImmutableMap.of(TEST_COLUMN_HANDLE, NullableValue.asNull(VARCHAR)))); - createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()); + Domain domain = createPredicate(ImmutableList.of(TEST_COLUMN_HANDLE), partitions.build()) + .getDomains().orElseThrow().get(TEST_COLUMN_HANDLE); + assertEquals(domain, Domain.create(ValueSet.of(VARCHAR, utf8Slice("0"), utf8Slice("1"), utf8Slice("2"), utf8Slice("3"), utf8Slice("4")), true)); } } diff --git a/presto-main/src/main/java/io/prestosql/operator/DynamicFilterSourceOperator.java b/presto-main/src/main/java/io/prestosql/operator/DynamicFilterSourceOperator.java index 2dd068a2dfd0..b21cd830ee76 100644 --- a/presto-main/src/main/java/io/prestosql/operator/DynamicFilterSourceOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/DynamicFilterSourceOperator.java @@ -36,6 +36,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; @@ -253,9 +254,13 @@ private Domain convertToDomain(Type type, Block block) for (int position = 0; position < block.getPositionCount(); ++position) { Object value = TypeUtils.readNativeValue(type, block, position); if (value != null) { - values.add(value); + // join doesn't match rows with NaN values. + if (!isFloatingPointNaN(type, value)) { + values.add(value); + } } } + // Inner and right join doesn't match rows with null key column values. return Domain.create(ValueSet.copyOf(type, values.build()), false); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index c409b8909bad..6814cc2ca7b4 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -77,6 +77,7 @@ import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; import static io.airlift.slice.SliceUtf8.setCodePointAt; import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static io.prestosql.sql.ExpressionUtils.and; import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.combineDisjunctsWithDefault; @@ -89,8 +90,6 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; import static java.util.stream.Collectors.toList; @@ -621,8 +620,7 @@ private static Optional extractOrderableDomain(ComparisonExpression.Oper } // Handle comparisons against NaN - if ((type instanceof DoubleType && Double.isNaN((double) value)) || - (type instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) value))))) { + if (isFloatingPointNaN(type, value)) { switch (comparisonOperator) { case EQUAL: case GREATER_THAN: diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java index 337c78b4a0e4..28658b9bc47a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/EffectivePredicateExtractor.java @@ -60,6 +60,7 @@ import java.util.stream.Collectors; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.expressionOrNullSymbols; import static io.prestosql.sql.ExpressionUtils.extractConjuncts; @@ -332,6 +333,7 @@ public Expression visitValues(ValuesNode node, Void context) ImmutableList.Builder builder = ImmutableList.builder(); boolean hasNull = false; + boolean hasNaN = false; boolean nonDeterministic = false; for (int row = 0; row < node.getRows().size(); row++) { Expression value = node.getRows().get(row).get(column); @@ -352,6 +354,9 @@ public Expression visitValues(ValuesNode node, Void context) hasNull = true; } else { + if (isFloatingPointNaN(type, evaluated)) { + hasNaN = true; + } builder.add(evaluated); } } @@ -364,10 +369,15 @@ public Expression visitValues(ValuesNode node, Void context) List values = builder.build(); - Domain domain = Domain.none(type); - - if (!values.isEmpty()) { - domain = domain.union(Domain.multipleValues(type, values)); + Domain domain; + if (values.isEmpty()) { + domain = Domain.none(type); + } + else if (hasNaN) { + domain = Domain.notNull(type); + } + else { + domain = Domain.multipleValues(type, values); } if (hasNull) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java index a62b8d928f49..cc676f8aeb6e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -46,6 +46,7 @@ import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.IntegerType.INTEGER; import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static io.prestosql.sql.ExpressionUtils.and; import static io.prestosql.sql.ExpressionUtils.or; import static io.prestosql.sql.analyzer.TypeSignatureTranslator.toSqlType; @@ -188,7 +189,7 @@ private Expression unwrapCast(ComparisonExpression expression) // Handle comparison against NaN. // It must be done before source type range bounds are compared to target value. - if ((targetType instanceof DoubleType && Double.isNaN((double) right)) || (targetType instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) right))))) { + if (isFloatingPointNaN(targetType, right)) { switch (operator) { case EQUAL: case GREATER_THAN: diff --git a/presto-main/src/test/java/io/prestosql/operator/TestDynamicFilterSourceOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestDynamicFilterSourceOperator.java index 1426b3d54123..11ecc62c02d3 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestDynamicFilterSourceOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestDynamicFilterSourceOperator.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.predicate.ValueSet; @@ -50,9 +51,11 @@ import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static io.prestosql.testing.assertions.Assert.assertEquals; +import static java.lang.Float.floatToRawIntBits; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.stream.Collectors.toList; @@ -217,6 +220,42 @@ public void testCollectWithNulls() new DynamicFilterId("0"), Domain.create(ValueSet.of(INTEGER, 1L, 2L, 3L, 4L, 5L), false))))); } + @Test + public void testCollectWithDoubleNaN() + { + BlockBuilder input = DOUBLE.createBlockBuilder(null, 10); + DOUBLE.writeDouble(input, 42.0); + DOUBLE.writeDouble(input, Double.NaN); + + OperatorFactory operatorFactory = createOperatorFactory(channel(0, DOUBLE)); + verifyPassthrough(createOperator(operatorFactory), + ImmutableList.of(DOUBLE), + new Page(input.build())); + operatorFactory.noMoreOperators(); + + assertEquals(partitions.build(), ImmutableList.of( + TupleDomain.withColumnDomains(ImmutableMap.of( + new DynamicFilterId("0"), Domain.multipleValues(DOUBLE, ImmutableList.of(42.0)))))); + } + + @Test + public void testCollectWithRealNaN() + { + BlockBuilder input = REAL.createBlockBuilder(null, 10); + REAL.writeLong(input, floatToRawIntBits(42.0f)); + REAL.writeLong(input, floatToRawIntBits(Float.NaN)); + + OperatorFactory operatorFactory = createOperatorFactory(channel(0, REAL)); + verifyPassthrough(createOperator(operatorFactory), + ImmutableList.of(REAL), + new Page(input.build())); + operatorFactory.noMoreOperators(); + + assertEquals(partitions.build(), ImmutableList.of( + TupleDomain.withColumnDomains(ImmutableMap.of( + new DynamicFilterId("0"), Domain.multipleValues(REAL, ImmutableList.of((long) floatToRawIntBits(42.0f))))))); + } + @Test public void testCollectNoFilters() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java index debf237c8f40..e0912b14801a 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestEffectivePredicateExtractor.java @@ -64,6 +64,7 @@ import io.prestosql.sql.tree.BooleanLiteral; import io.prestosql.sql.tree.Cast; import io.prestosql.sql.tree.ComparisonExpression; +import io.prestosql.sql.tree.DoubleLiteral; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionTreeRewriter; import io.prestosql.sql.tree.FunctionCall; @@ -72,6 +73,7 @@ import io.prestosql.sql.tree.InPredicate; import io.prestosql.sql.tree.IsNullPredicate; import io.prestosql.sql.tree.LongLiteral; +import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.NullLiteral; import io.prestosql.sql.tree.QualifiedName; import io.prestosql.testing.TestingMetadata.TestingColumnHandle; @@ -96,6 +98,8 @@ import static io.prestosql.metadata.FunctionId.toFunctionId; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.sql.ExpressionUtils.and; import static io.prestosql.sql.ExpressionUtils.combineConjuncts; import static io.prestosql.sql.ExpressionUtils.or; @@ -522,6 +526,7 @@ public void testValues() TypeProvider types = TypeProvider.copyOf(ImmutableMap.builder() .put(A, BIGINT) .put(B, BIGINT) + .put(D, DOUBLE) .build()); // one column @@ -583,6 +588,58 @@ public void testValues() typeAnalyzer), new BetweenPredicate(AE, bigintLiteral(0), bigintLiteral(499))); + // NaN + assertEquals( + effectivePredicateExtractor.extract( + SESSION, + new ValuesNode( + newId(), + ImmutableList.of(D), + ImmutableList.of(ImmutableList.of(doubleLiteral(Double.NaN)))), + types, + typeAnalyzer), + new NotExpression(new IsNullPredicate(DE))); + + // NaN and NULL + assertEquals( + effectivePredicateExtractor.extract( + SESSION, + new ValuesNode( + newId(), + ImmutableList.of(D), + ImmutableList.of( + ImmutableList.of(new Cast(new NullLiteral(), toSqlType(DOUBLE))), + ImmutableList.of(doubleLiteral(Double.NaN)))), + types, + typeAnalyzer), + TRUE_LITERAL); + + // NaN and value + assertEquals( + effectivePredicateExtractor.extract( + SESSION, + new ValuesNode( + newId(), + ImmutableList.of(D), + ImmutableList.of( + ImmutableList.of(doubleLiteral(42.)), + ImmutableList.of(doubleLiteral(Double.NaN)))), + types, + typeAnalyzer), + new NotExpression(new IsNullPredicate(DE))); + + // Real NaN + assertEquals( + effectivePredicateExtractor.extract( + SESSION, + new ValuesNode( + newId(), + ImmutableList.of(D), + ImmutableList.of(ImmutableList.of(new Cast(doubleLiteral(Double.NaN), toSqlType(REAL))))), + TypeProvider.copyOf(ImmutableMap.of(D, REAL)), + typeAnalyzer), + new NotExpression(new IsNullPredicate(DE))); + // multiple columns assertEquals( effectivePredicateExtractor.extract( @@ -1037,6 +1094,11 @@ private static Expression bigintLiteral(long number) return new LongLiteral(String.valueOf(number)); } + private static Expression doubleLiteral(double value) + { + return new DoubleLiteral(String.valueOf(value)); + } + private static ComparisonExpression equals(Expression expression1, Expression expression2) { return new ComparisonExpression(EQUAL, expression1, expression2); diff --git a/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java b/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java index f710344dfd6d..c0986e30f41b 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java +++ b/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java @@ -17,8 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.type.DoubleType; -import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import java.util.Objects; @@ -27,8 +25,7 @@ import static io.prestosql.spi.predicate.Utils.blockToNativeValue; import static io.prestosql.spi.predicate.Utils.nativeValueToBlock; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -73,11 +70,8 @@ public Marker( if (valueBlock.isPresent() && valueBlock.get().getPositionCount() != 1) { throw new IllegalArgumentException("value block should only have one position"); } - if (type instanceof RealType && valueBlock.isPresent() && Float.isNaN(intBitsToFloat(toIntExact((long) blockToNativeValue(type, valueBlock.get()))))) { - throw new IllegalArgumentException("cannot use Real NaN as range bound"); - } - if (type instanceof DoubleType && valueBlock.isPresent() && Double.isNaN((double) blockToNativeValue(type, valueBlock.get()))) { - throw new IllegalArgumentException("cannot use Double NaN as range bound"); + if (valueBlock.isPresent() && isFloatingPointNaN(type, blockToNativeValue(type, valueBlock.get()))) { + throw new IllegalArgumentException("cannot use NaN as range bound"); } this.type = type; this.valueBlock = valueBlock; diff --git a/presto-spi/src/main/java/io/prestosql/spi/predicate/SortedRangeSet.java b/presto-spi/src/main/java/io/prestosql/spi/predicate/SortedRangeSet.java index 381997e7fe23..4e3fa2fdb92e 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/predicate/SortedRangeSet.java +++ b/presto-spi/src/main/java/io/prestosql/spi/predicate/SortedRangeSet.java @@ -32,6 +32,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static io.prestosql.spi.type.TypeUtils.isFloatingPointNaN; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; @@ -181,6 +182,9 @@ public List getDiscreteSet() @Override public boolean containsValue(Object value) { + if (isFloatingPointNaN(type, value)) { + return isAll(); + } return includesMarker(Marker.exactly(type, value)); } diff --git a/presto-spi/src/main/java/io/prestosql/spi/type/TypeUtils.java b/presto-spi/src/main/java/io/prestosql/spi/type/TypeUtils.java index ae34a126d6a3..d2dd5818dc5b 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/type/TypeUtils.java +++ b/presto-spi/src/main/java/io/prestosql/spi/type/TypeUtils.java @@ -20,6 +20,11 @@ import io.prestosql.spi.block.BlockBuilder; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; public final class TypeUtils { @@ -89,6 +94,20 @@ else if (value instanceof String) { } } + public static boolean isFloatingPointNaN(Type type, Object value) + { + requireNonNull(type, "type is null"); + requireNonNull(value, "value is null"); + + if (type == DOUBLE) { + return Double.isNaN((double) value); + } + if (type == REAL) { + return Float.isNaN(intBitsToFloat(toIntExact((long) value))); + } + return false; + } + static long hashPosition(Type type, Block block, int position) { if (block.isNull(position)) { diff --git a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestDomain.java b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestDomain.java index a514267d71d8..07bb13e370bb 100644 --- a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestDomain.java +++ b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestDomain.java @@ -30,8 +30,11 @@ import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.HyperLogLogType.HYPER_LOG_LOG; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.TestingIdType.ID; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static java.lang.Double.longBitsToDouble; +import static java.lang.Float.floatToRawIntBits; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -108,6 +111,46 @@ public void testOrderableAll() assertEquals(domain.complement(), Domain.none(BIGINT)); } + @Test + public void testFloatingPointOrderableAll() + { + Domain domain = Domain.all(REAL); + assertFalse(domain.isNone()); + assertTrue(domain.isAll()); + assertFalse(domain.isSingleValue()); + assertFalse(domain.isNullableSingleValue()); + assertFalse(domain.isOnlyNull()); + assertTrue(domain.isNullAllowed()); + assertEquals(domain.getValues(), ValueSet.all(REAL)); + assertEquals(domain.getType(), REAL); + assertTrue(domain.includesNullableValue((long) floatToRawIntBits(-Float.MAX_VALUE))); + assertTrue(domain.includesNullableValue((long) floatToRawIntBits(0.0f))); + assertTrue(domain.includesNullableValue((long) floatToRawIntBits(Float.MAX_VALUE))); + assertTrue(domain.includesNullableValue((long) floatToRawIntBits(Float.MIN_VALUE))); + assertTrue(domain.includesNullableValue(null)); + assertTrue(domain.includesNullableValue((long) floatToRawIntBits(Float.NaN))); + assertTrue(domain.includesNullableValue((long) 0x7fc01234)); // different NaN representation + assertEquals(domain.complement(), Domain.none(REAL)); + + domain = Domain.all(DOUBLE); + assertFalse(domain.isNone()); + assertTrue(domain.isAll()); + assertFalse(domain.isSingleValue()); + assertFalse(domain.isNullableSingleValue()); + assertFalse(domain.isOnlyNull()); + assertTrue(domain.isNullAllowed()); + assertEquals(domain.getValues(), ValueSet.all(DOUBLE)); + assertEquals(domain.getType(), DOUBLE); + assertTrue(domain.includesNullableValue(-Double.MAX_VALUE)); + assertTrue(domain.includesNullableValue(0.0)); + assertTrue(domain.includesNullableValue(Double.MAX_VALUE)); + assertTrue(domain.includesNullableValue(Double.MIN_VALUE)); + assertTrue(domain.includesNullableValue(null)); + assertTrue(domain.includesNullableValue(Double.NaN)); + assertTrue(domain.includesNullableValue(longBitsToDouble(0x7ff8123412341234L))); // different NaN representation + assertEquals(domain.complement(), Domain.none(DOUBLE)); + } + @Test public void testEquatableAll() { diff --git a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestSortedRangeSet.java b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestSortedRangeSet.java index eccf26337b62..eb16c3819b40 100644 --- a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestSortedRangeSet.java +++ b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestSortedRangeSet.java @@ -24,17 +24,23 @@ import io.prestosql.spi.type.TestingTypeDeserializer; import io.prestosql.spi.type.TestingTypeManager; import io.prestosql.spi.type.Type; +import org.assertj.core.api.AssertProvider; import org.testng.annotations.Test; import static io.airlift.slice.Slices.utf8Slice; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class TestSortedRangeSet { @@ -166,23 +172,17 @@ public void testUnboundedSet() public void testGetSingleValue() { assertEquals(SortedRangeSet.of(BIGINT, 0L).getSingleValue(), 0L); - try { - SortedRangeSet.all(BIGINT).getSingleValue(); - fail(); - } - catch (IllegalStateException e) { - } + assertThatThrownBy(() -> SortedRangeSet.all(BIGINT).getSingleValue()) + .isInstanceOf(IllegalStateException.class) + .hasMessage("SortedRangeSet does not have just a single value"); } @Test public void testSpan() { - try { - SortedRangeSet.none(BIGINT).getSpan(); - fail(); - } - catch (IllegalStateException e) { - } + assertThatThrownBy(() -> SortedRangeSet.none(BIGINT).getSpan()) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Cannot get span if no ranges exist"); assertEquals(SortedRangeSet.all(BIGINT).getSpan(), Range.all(BIGINT)); assertEquals(SortedRangeSet.of(BIGINT, 0L).getSpan(), Range.equal(BIGINT, 0L)); @@ -253,6 +253,72 @@ public void testContains() assertFalse(SortedRangeSet.of(Range.lessThan(BIGINT, 0L)).contains(SortedRangeSet.of(Range.greaterThan(BIGINT, 0L)))); } + @Test + public void testContainsValue() + { + // BIGINT all + assertSortedRangeSet(SortedRangeSet.all(BIGINT)) + .containsValue(Long.MIN_VALUE) + .containsValue(0L) + .containsValue(42L) + .containsValue(Long.MAX_VALUE); + + // BIGINT range + assertSortedRangeSet(SortedRangeSet.of(Range.range(BIGINT, 10L, true, 41L, true))) + .doesNotContainValue(9L) + .containsValue(10L) + .containsValue(11L) + .containsValue(30L) + .containsValue(41L) + .doesNotContainValue(42L); + + assertSortedRangeSet(SortedRangeSet.of(Range.range(BIGINT, 10L, false, 41L, false))) + .doesNotContainValue(10L) + .containsValue(11L) + .containsValue(40L) + .doesNotContainValue(41L); + + // REAL all + assertSortedRangeSet(SortedRangeSet.all(REAL)) + .containsValue((long) floatToRawIntBits(42.0f)) + .containsValue((long) floatToRawIntBits(Float.NaN)); + + // REAL range + assertSortedRangeSet(SortedRangeSet.of(Range.range(REAL, (long) floatToRawIntBits(10.0f), true, (long) floatToRawIntBits(41.0f), true))) + .doesNotContainValue((long) floatToRawIntBits(9.999999f)) + .containsValue((long) floatToRawIntBits(10.0f)) + .containsValue((long) floatToRawIntBits(41.0f)) + .doesNotContainValue((long) floatToRawIntBits(41.00001f)) + .doesNotContainValue((long) floatToRawIntBits(Float.NaN)); + + assertSortedRangeSet(SortedRangeSet.of(Range.range(REAL, (long) floatToRawIntBits(10.0f), false, (long) floatToRawIntBits(41.0f), false))) + .doesNotContainValue((long) floatToRawIntBits(10.0f)) + .containsValue((long) floatToRawIntBits(10.00001f)) + .containsValue((long) floatToRawIntBits(40.99999f)) + .doesNotContainValue((long) floatToRawIntBits(41.0f)) + .doesNotContainValue((long) floatToRawIntBits(Float.NaN)); + + // DOUBLE all + assertSortedRangeSet(SortedRangeSet.all(DOUBLE)) + .containsValue(42.0) + .containsValue(Double.NaN); + + // DOUBLE range + assertSortedRangeSet(SortedRangeSet.of(Range.range(DOUBLE, 10.0, true, 41.0, true))) + .doesNotContainValue(9.999999999999999) + .containsValue(10.0) + .containsValue(41.0) + .doesNotContainValue(41.00000000000001) + .doesNotContainValue(Double.NaN); + + assertSortedRangeSet(SortedRangeSet.of(Range.range(DOUBLE, 10.0, false, 41.0, false))) + .doesNotContainValue(10.0) + .containsValue(10.00000000000001) + .containsValue(40.99999999999999) + .doesNotContainValue(41.0) + .doesNotContainValue(Double.NaN); + } + @Test public void testIntersect() { @@ -442,4 +508,35 @@ private void assertUnion(SortedRangeSet first, SortedRangeSet second, SortedRang assertEquals(first.union(second), expected); assertEquals(first.union(ImmutableList.of(first, second)), expected); } + + private static SortedRangeSetAssert assertSortedRangeSet(SortedRangeSet sortedRangeSet) + { + return assertThat((AssertProvider) () -> new SortedRangeSetAssert(sortedRangeSet)); + } + + private static class SortedRangeSetAssert + { + private final SortedRangeSet sortedRangeSet; + + public SortedRangeSetAssert(SortedRangeSet sortedRangeSet) + { + this.sortedRangeSet = requireNonNull(sortedRangeSet, "sortedRangeSet is null"); + } + + public SortedRangeSetAssert containsValue(Object value) + { + if (!sortedRangeSet.containsValue(value)) { + throw new AssertionError(format("Expected %s to contain %s", sortedRangeSet, value)); + } + return this; + } + + public SortedRangeSetAssert doesNotContainValue(Object value) + { + if (sortedRangeSet.containsValue(value)) { + throw new AssertionError(format("Expected %s not to contain %s", sortedRangeSet, value)); + } + return this; + } + } }