diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java index 47179516a388..cfc370d4e55c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/DomainTranslator.java @@ -31,6 +31,8 @@ import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.predicate.Utils; import io.prestosql.spi.predicate.ValueSet; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import io.prestosql.sql.ExpressionUtils; import io.prestosql.sql.InterpretedFunctionInvoker; @@ -77,6 +79,8 @@ import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static io.prestosql.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; import static java.util.stream.Collectors.toList; @@ -368,9 +372,30 @@ protected ExtractionResult visitLogicalBinaryExpression(LogicalBinaryExpression && leftTupleDomain.getDomains().get().keySet().equals(rightTupleDomain.getDomains().get().keySet()); boolean oneSideIsSuperSet = leftTupleDomain.contains(rightTupleDomain) || rightTupleDomain.contains(leftTupleDomain); - if (matchingSingleSymbolDomains || oneSideIsSuperSet) { + if (oneSideIsSuperSet) { remainingExpression = leftResult.getRemainingExpression(); } + else if (matchingSingleSymbolDomains) { + // Types REAL and DOUBLE require special handling because they include NaN value. In this case, we cannot rely on the union of domains. + // That is because domains covering the value set partially might union up to a domain covering the whole value set. + // While the component domains didn't include NaN, the resulting domain could be further translated to predicate "TRUE" or "a IS NOT NULL", + // which is satisfied by NaN. So during domain union, NaN might be implicitly added. + // Example: Let 'a' be a column of type DOUBLE. + // Let left TupleDomain => (a > 0) /false for NaN/, right TupleDomain => (a < 10) /false for NaN/. + // Unioned TupleDomain => "is not null" /true for NaN/ + // To guard against wrong results, the current node is returned as the remainingExpression. + Domain leftDomain = getOnlyElement(leftTupleDomain.getDomains().get().values()); + Domain rightDomain = getOnlyElement(rightTupleDomain.getDomains().get().values()); + Domain unionedDomain = getOnlyElement(columnUnionedTupleDomain.getDomains().get().values()); + Type type = leftDomain.getType(); + boolean implicitlyAddedNaN = (type instanceof RealType || type instanceof DoubleType) && + !leftDomain.getValues().isAll() && + !rightDomain.getValues().isAll() && + unionedDomain.getValues().isAll(); + if (!implicitlyAddedNaN) { + remainingExpression = leftResult.getRemainingExpression(); + } + } } return new ExtractionResult(columnUnionedTupleDomain, remainingExpression); @@ -400,7 +425,8 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, Symbol symbol = Symbol.from(symbolExpression); NullableValue value = normalized.getValue(); Type type = value.getType(); // common type for symbol and value - return createComparisonExtractionResult(normalized.getComparisonOperator(), symbol, type, value.getValue(), complement); + return createComparisonExtractionResult(normalized.getComparisonOperator(), symbol, type, value.getValue(), complement) + .orElseGet(() -> super.visitComparisonExpression(node, complement)); } if (symbolExpression instanceof Cast) { Cast castExpression = (Cast) symbolExpression; @@ -492,7 +518,7 @@ private Map, Type> analyzeExpression(Expression expression) return typeAnalyzer.getTypes(session, types, expression); } - private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) + private static Optional createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) { if (value == null) { switch (comparisonOperator) { @@ -502,54 +528,141 @@ private static ExtractionResult createComparisonExtractionResult(ComparisonExpre case LESS_THAN: case LESS_THAN_OR_EQUAL: case NOT_EQUAL: - return new ExtractionResult(TupleDomain.none(), TRUE_LITERAL); + return Optional.of(new ExtractionResult(TupleDomain.none(), TRUE_LITERAL)); case IS_DISTINCT_FROM: Domain domain = complementIfNecessary(Domain.notNull(type), complement); - return new ExtractionResult( + return Optional.of(new ExtractionResult( TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL); + TRUE_LITERAL)); default: throw new AssertionError("Unhandled operator: " + comparisonOperator); } } - - Domain domain; if (type.isOrderable()) { - domain = extractOrderableDomain(comparisonOperator, type, value, complement); + return extractOrderableDomain(comparisonOperator, type, value, complement) + .map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), TRUE_LITERAL)); } - else if (type.isComparable()) { - domain = extractEquatableDomain(comparisonOperator, type, value, complement); - } - else { - throw new AssertionError("Type cannot be used in a comparison expression (should have been caught in analysis): " + type); + if (type.isComparable()) { + Domain domain = extractEquatableDomain(comparisonOperator, type, value, complement); + return Optional.of(new ExtractionResult( + TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), + TRUE_LITERAL)); } - - return new ExtractionResult( - TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL); + throw new AssertionError("Type cannot be used in a comparison expression (should have been caught in analysis): " + type); } - private static Domain extractOrderableDomain(ComparisonExpression.Operator comparisonOperator, Type type, Object value, boolean complement) + private static Optional extractOrderableDomain(ComparisonExpression.Operator comparisonOperator, Type type, Object value, boolean complement) { checkArgument(value != null); + + // Handle orderable types which do not have NaN. + if (!(type instanceof DoubleType) && !(type instanceof RealType)) { + switch (comparisonOperator) { + case EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.equal(type, value)), complement), false)); + case GREATER_THAN: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.greaterThan(type, value)), complement), false)); + case GREATER_THAN_OR_EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.greaterThanOrEqual(type, value)), complement), false)); + case LESS_THAN: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThan(type, value)), complement), false)); + case LESS_THAN_OR_EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThanOrEqual(type, value)), complement), false)); + case NOT_EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThan(type, value), Range.greaterThan(type, value)), complement), false)); + case IS_DISTINCT_FROM: + // Need to potential complement the whole domain for IS_DISTINCT_FROM since it is null-aware + return Optional.of(complementIfNecessary(Domain.create(ValueSet.ofRanges(Range.lessThan(type, value), Range.greaterThan(type, value)), true), complement)); + default: + throw new AssertionError("Unhandled operator: " + comparisonOperator); + } + } + + // Handle comparisons against NaN + if ((type instanceof DoubleType && Double.isNaN((double) value)) || + (type instanceof RealType && Float.isNaN(intBitsToFloat(toIntExact((long) value))))) { + switch (comparisonOperator) { + case EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.none(type), complement), false)); + + case NOT_EQUAL: + return Optional.of(Domain.create(complementIfNecessary(ValueSet.all(type), complement), false)); + + case IS_DISTINCT_FROM: + // The Domain should be "all but NaN". It is currently not supported. + return Optional.empty(); + + default: + throw new AssertionError("Unhandled operator: " + comparisonOperator); + } + } + + // Handle comparisons against a non-NaN value when the compared value might be NaN switch (comparisonOperator) { + /* + For comparison operators: EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, + the Domain should not contain NaN, but complemented Domain should contain NaN. It is currently not supported. + Currently, NaN is only included when ValueSet.isAll(). + + For comparison operators: NOT_EQUAL, IS_DISTINCT_FROM, + the Domain should consist of ranges (which do not sum to the whole ValueSet), and NaN. + Currently, NaN is only included when ValueSet.isAll(). + */ case EQUAL: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.equal(type, value)), complement), false); + if (complement) { + return Optional.empty(); + } + else { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.equal(type, value)), false)); + } case GREATER_THAN: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.greaterThan(type, value)), complement), false); + if (complement) { + return Optional.empty(); + } + else { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.greaterThan(type, value)), false)); + } case GREATER_THAN_OR_EQUAL: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.greaterThanOrEqual(type, value)), complement), false); + if (complement) { + return Optional.empty(); + } + else { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(type, value)), false)); + } case LESS_THAN: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThan(type, value)), complement), false); + if (complement) { + return Optional.empty(); + } + else { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThan(type, value)), false)); + } case LESS_THAN_OR_EQUAL: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThanOrEqual(type, value)), complement), false); + if (complement) { + return Optional.empty(); + } + else { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(type, value)), false)); + } case NOT_EQUAL: - return Domain.create(complementIfNecessary(ValueSet.ofRanges(Range.lessThan(type, value), Range.greaterThan(type, value)), complement), false); + if (complement) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.equal(type, value)), false)); + } + else { + return Optional.empty(); + } case IS_DISTINCT_FROM: - // Need to potential complement the whole domain for IS_DISTINCT_FROM since it is null-aware - return complementIfNecessary(Domain.create(ValueSet.ofRanges(Range.lessThan(type, value), Range.greaterThan(type, value)), true), complement); + if (complement) { + return Optional.of(Domain.create(ValueSet.ofRanges(Range.equal(type, value)), false)); + } + else { + return Optional.empty(); + } default: throw new AssertionError("Unhandled operator: " + comparisonOperator); } @@ -631,8 +744,8 @@ private Expression rewriteComparisonExpression( return new ComparisonExpression(EQUAL, symbolExpression, coercedLiteral); } // Return something that is false for all non-null values - return and(new ComparisonExpression(EQUAL, symbolExpression, coercedLiteral), - new ComparisonExpression(NOT_EQUAL, symbolExpression, coercedLiteral)); + return and(new ComparisonExpression(GREATER_THAN, symbolExpression, coercedLiteral), + new ComparisonExpression(LESS_THAN, symbolExpression, coercedLiteral)); } case NOT_EQUAL: { if (coercedValueIsEqualToOriginal) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java index 0c0d9964f9bf..d7817ea729a4 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDomainTranslator.java @@ -24,6 +24,8 @@ import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.predicate.ValueSet; import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import io.prestosql.sql.planner.DomainTranslator.ExtractionResult; import io.prestosql.sql.tree.BetweenPredicate; @@ -91,6 +93,7 @@ import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -411,6 +414,23 @@ public void testFromOrPredicate() and(equal(C_BIGINT, bigintLiteral(1L)), unprocessableExpression1(C_BIGINT)), and(equal(C_DOUBLE, doubleLiteral(2.0)), unprocessableExpression1(C_BIGINT)))); + // Domain union implicitly adds NaN as an accepted value + // The original predicate is returned as the RemainingExpression + // (even though left and right unprocessableExpressions are the same) + originalPredicate = or( + and(greaterThan(C_DOUBLE, doubleLiteral(2.0)), unprocessableExpression1(C_DOUBLE)), + and(lessThan(C_DOUBLE, doubleLiteral(5.0)), unprocessableExpression1(C_DOUBLE))); + result = fromPredicate(originalPredicate); + assertEquals(result.getRemainingExpression(), originalPredicate); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + + originalPredicate = or( + and(greaterThan(C_REAL, realLiteral("2.0")), unprocessableExpression1(C_REAL)), + and(lessThan(C_REAL, realLiteral("5.0")), unprocessableExpression1(C_REAL))); + result = fromPredicate(originalPredicate); + assertEquals(result.getRemainingExpression(), originalPredicate); + assertEquals(result.getTupleDomain(), withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + // We can make another optimization if one side is the super set of the other side originalPredicate = or( and(greaterThan(C_BIGINT, bigintLiteral(1L)), greaterThan(C_DOUBLE, doubleLiteral(1.0)), unprocessableExpression1(C_BIGINT)), @@ -643,6 +663,46 @@ public void testFromBasicComparisonsWithNulls() withColumnDomains(ImmutableMap.of(C_COLOR, Domain.onlyNull(COLOR)))); } + @Test + public void testFromBasicComparisonsWithNaN() + { + Expression nanDouble = literalEncoder.toExpression(Double.NaN, DOUBLE); + + assertPredicateIsAlwaysFalse(equal(C_DOUBLE, nanDouble)); + assertPredicateIsAlwaysFalse(greaterThan(C_DOUBLE, nanDouble)); + assertPredicateIsAlwaysFalse(greaterThanOrEqual(C_DOUBLE, nanDouble)); + assertPredicateIsAlwaysFalse(lessThan(C_DOUBLE, nanDouble)); + assertPredicateIsAlwaysFalse(lessThanOrEqual(C_DOUBLE, nanDouble)); + assertPredicateTranslates(notEqual(C_DOUBLE, nanDouble), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertUnsupportedPredicate(isDistinctFrom(C_DOUBLE, nanDouble)); + + assertPredicateTranslates(not(equal(C_DOUBLE, nanDouble)), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertPredicateTranslates(not(greaterThan(C_DOUBLE, nanDouble)), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertPredicateTranslates(not(greaterThanOrEqual(C_DOUBLE, nanDouble)), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertPredicateTranslates(not(lessThan(C_DOUBLE, nanDouble)), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertPredicateTranslates(not(lessThanOrEqual(C_DOUBLE, nanDouble)), TupleDomain.withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.notNull(DOUBLE)))); + assertPredicateIsAlwaysFalse(not(notEqual(C_DOUBLE, nanDouble))); + assertUnsupportedPredicate(not(isDistinctFrom(C_DOUBLE, nanDouble))); + + Expression nanReal = literalEncoder.toExpression((long) Float.floatToIntBits(Float.NaN), REAL); + + assertPredicateIsAlwaysFalse(equal(C_REAL, nanReal)); + assertPredicateIsAlwaysFalse(greaterThan(C_REAL, nanReal)); + assertPredicateIsAlwaysFalse(greaterThanOrEqual(C_REAL, nanReal)); + assertPredicateIsAlwaysFalse(lessThan(C_REAL, nanReal)); + assertPredicateIsAlwaysFalse(lessThanOrEqual(C_REAL, nanReal)); + assertPredicateTranslates(notEqual(C_REAL, nanReal), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertUnsupportedPredicate(isDistinctFrom(C_REAL, nanReal)); + + assertPredicateTranslates(not(equal(C_REAL, nanReal)), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertPredicateTranslates(not(greaterThan(C_REAL, nanReal)), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertPredicateTranslates(not(greaterThanOrEqual(C_REAL, nanReal)), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertPredicateTranslates(not(lessThan(C_REAL, nanReal)), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertPredicateTranslates(not(lessThanOrEqual(C_REAL, nanReal)), TupleDomain.withColumnDomains(ImmutableMap.of(C_REAL, Domain.notNull(REAL)))); + assertPredicateIsAlwaysFalse(not(notEqual(C_REAL, nanReal))); + assertUnsupportedPredicate(not(isDistinctFrom(C_REAL, nanReal))); + } + @Test public void testNonImplicitCastOnSymbolSide() { @@ -748,8 +808,8 @@ public void testFromComparisonsWithCoercions() // B is a double column. Check that it can be compared against longs assertPredicateTranslates( - not(greaterThan(C_DOUBLE, cast(bigintLiteral(2L), DOUBLE))), - withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(DOUBLE, 2.0)), false)))); + greaterThan(C_DOUBLE, cast(bigintLiteral(2L), DOUBLE)), + withColumnDomains(ImmutableMap.of(C_DOUBLE, Domain.create(ValueSet.ofRanges(Range.greaterThan(DOUBLE, 2.0)), false)))); // C is a string column. Check that it can be compared. assertPredicateTranslates( @@ -1107,6 +1167,26 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa testSimpleComparison(greaterThanOrEqual(columnExpression, fractionalNegative), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalNegative())); } + // greater than or equal negated + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, integerPositive))); + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, integerNegative))); + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, max))); + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, min))); + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(greaterThanOrEqual(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(greaterThanOrEqual(columnExpression, integerPositive)), columnSymbol, Range.lessThan(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(greaterThanOrEqual(columnExpression, integerNegative)), columnSymbol, Range.lessThan(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(greaterThanOrEqual(columnExpression, max)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getMax())); + testSimpleComparison(not(greaterThanOrEqual(columnExpression, min)), columnSymbol, Range.lessThan(columnType, columnValues.getMin())); + if (literalValues.isFractional()) { + testSimpleComparison(not(greaterThanOrEqual(columnExpression, fractionalPositive)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalPositive())); + testSimpleComparison(not(greaterThanOrEqual(columnExpression, fractionalNegative)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalNegative())); + } + } + // greater than testSimpleComparison(greaterThan(columnExpression, integerPositive), columnSymbol, Range.greaterThan(columnType, columnValues.getIntegerPositive())); testSimpleComparison(greaterThan(columnExpression, integerNegative), columnSymbol, Range.greaterThan(columnType, columnValues.getIntegerNegative())); @@ -1117,6 +1197,26 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa testSimpleComparison(greaterThan(columnExpression, fractionalNegative), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalNegative())); } + // greater than negated + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(not(greaterThan(columnExpression, integerPositive))); + assertNoFullPushdown(not(greaterThan(columnExpression, integerNegative))); + assertNoFullPushdown(not(greaterThan(columnExpression, max))); + assertNoFullPushdown(not(greaterThan(columnExpression, min))); + assertNoFullPushdown(not(greaterThan(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(greaterThan(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(greaterThan(columnExpression, integerPositive)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(greaterThan(columnExpression, integerNegative)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(greaterThan(columnExpression, max)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getMax())); + testSimpleComparison(not(greaterThan(columnExpression, min)), columnSymbol, Range.lessThan(columnType, columnValues.getMin())); + if (literalValues.isFractional()) { + testSimpleComparison(not(greaterThan(columnExpression, fractionalPositive)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalPositive())); + testSimpleComparison(not(greaterThan(columnExpression, fractionalNegative)), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalNegative())); + } + } + // less than or equal testSimpleComparison(lessThanOrEqual(columnExpression, integerPositive), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getIntegerPositive())); testSimpleComparison(lessThanOrEqual(columnExpression, integerNegative), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getIntegerNegative())); @@ -1127,6 +1227,26 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa testSimpleComparison(lessThanOrEqual(columnExpression, fractionalNegative), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalNegative())); } + // less than or equal negated + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, integerPositive))); + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, integerNegative))); + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, max))); + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, min))); + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(lessThanOrEqual(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(lessThanOrEqual(columnExpression, integerPositive)), columnSymbol, Range.greaterThan(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(lessThanOrEqual(columnExpression, integerNegative)), columnSymbol, Range.greaterThan(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(lessThanOrEqual(columnExpression, max)), columnSymbol, Range.greaterThan(columnType, columnValues.getMax())); + testSimpleComparison(not(lessThanOrEqual(columnExpression, min)), columnSymbol, Range.greaterThanOrEqual(columnType, columnValues.getMin())); + if (literalValues.isFractional()) { + testSimpleComparison(not(lessThanOrEqual(columnExpression, fractionalPositive)), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalPositive())); + testSimpleComparison(not(lessThanOrEqual(columnExpression, fractionalNegative)), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalNegative())); + } + } + // less than testSimpleComparison(lessThan(columnExpression, integerPositive), columnSymbol, Range.lessThan(columnType, columnValues.getIntegerPositive())); testSimpleComparison(lessThan(columnExpression, integerNegative), columnSymbol, Range.lessThan(columnType, columnValues.getIntegerNegative())); @@ -1137,6 +1257,26 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa testSimpleComparison(lessThan(columnExpression, fractionalNegative), columnSymbol, Range.lessThanOrEqual(columnType, columnValues.getFractionalNegative())); } + // less than negated + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(not(lessThan(columnExpression, integerPositive))); + assertNoFullPushdown(not(lessThan(columnExpression, integerNegative))); + assertNoFullPushdown(not(lessThan(columnExpression, max))); + assertNoFullPushdown(not(lessThan(columnExpression, min))); + assertNoFullPushdown(not(lessThan(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(lessThan(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(lessThan(columnExpression, integerPositive)), columnSymbol, Range.greaterThanOrEqual(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(lessThan(columnExpression, integerNegative)), columnSymbol, Range.greaterThanOrEqual(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(lessThan(columnExpression, max)), columnSymbol, Range.greaterThan(columnType, columnValues.getMax())); + testSimpleComparison(not(lessThan(columnExpression, min)), columnSymbol, Range.greaterThanOrEqual(columnType, columnValues.getMin())); + if (literalValues.isFractional()) { + testSimpleComparison(not(lessThan(columnExpression, fractionalPositive)), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalPositive())); + testSimpleComparison(not(lessThan(columnExpression, fractionalNegative)), columnSymbol, Range.greaterThan(columnType, columnValues.getFractionalNegative())); + } + } + // equal testSimpleComparison(equal(columnExpression, integerPositive), columnSymbol, Range.equal(columnType, columnValues.getIntegerPositive())); testSimpleComparison(equal(columnExpression, integerNegative), columnSymbol, Range.equal(columnType, columnValues.getIntegerNegative())); @@ -1147,25 +1287,95 @@ private void testNumericTypeTranslation(NumericValues columnValues, NumericVa testSimpleComparison(equal(columnExpression, fractionalNegative), columnSymbol, Domain.none(columnType)); } + // equal negated + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(not(equal(columnExpression, integerPositive))); + assertNoFullPushdown(not(equal(columnExpression, integerNegative))); + assertNoFullPushdown(not(equal(columnExpression, max))); + assertNoFullPushdown(not(equal(columnExpression, min))); + assertNoFullPushdown(not(equal(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(equal(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(equal(columnExpression, integerPositive)), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerPositive()), Range.greaterThan(columnType, columnValues.getIntegerPositive())), false)); + testSimpleComparison(not(equal(columnExpression, integerNegative)), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerNegative()), Range.greaterThan(columnType, columnValues.getIntegerNegative())), false)); + testSimpleComparison(not(equal(columnExpression, max)), columnSymbol, Domain.notNull(columnType)); + testSimpleComparison(not(equal(columnExpression, min)), columnSymbol, Domain.notNull(columnType)); + if (literalValues.isFractional()) { + testSimpleComparison(not(equal(columnExpression, fractionalPositive)), columnSymbol, Domain.notNull(columnType)); + testSimpleComparison(not(equal(columnExpression, fractionalNegative)), columnSymbol, Domain.notNull(columnType)); + } + } + // not equal - testSimpleComparison(notEqual(columnExpression, integerPositive), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerPositive()), Range.greaterThan(columnType, columnValues.getIntegerPositive())), false)); - testSimpleComparison(notEqual(columnExpression, integerNegative), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerNegative()), Range.greaterThan(columnType, columnValues.getIntegerNegative())), false)); - testSimpleComparison(notEqual(columnExpression, max), columnSymbol, Domain.notNull(columnType)); - testSimpleComparison(notEqual(columnExpression, min), columnSymbol, Domain.notNull(columnType)); - if (literalValues.isFractional()) { - testSimpleComparison(notEqual(columnExpression, fractionalPositive), columnSymbol, Domain.notNull(columnType)); - testSimpleComparison(notEqual(columnExpression, fractionalNegative), columnSymbol, Domain.notNull(columnType)); + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(notEqual(columnExpression, integerPositive)); + assertNoFullPushdown(notEqual(columnExpression, integerNegative)); + assertNoFullPushdown(notEqual(columnExpression, max)); + assertNoFullPushdown(notEqual(columnExpression, min)); + assertNoFullPushdown(notEqual(columnExpression, fractionalPositive)); + assertNoFullPushdown(notEqual(columnExpression, integerNegative)); + } + else { + testSimpleComparison(notEqual(columnExpression, integerPositive), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerPositive()), Range.greaterThan(columnType, columnValues.getIntegerPositive())), false)); + testSimpleComparison(notEqual(columnExpression, integerNegative), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerNegative()), Range.greaterThan(columnType, columnValues.getIntegerNegative())), false)); + testSimpleComparison(notEqual(columnExpression, max), columnSymbol, Domain.notNull(columnType)); + testSimpleComparison(notEqual(columnExpression, min), columnSymbol, Domain.notNull(columnType)); + if (literalValues.isFractional()) { + testSimpleComparison(notEqual(columnExpression, fractionalPositive), columnSymbol, Domain.notNull(columnType)); + testSimpleComparison(notEqual(columnExpression, fractionalNegative), columnSymbol, Domain.notNull(columnType)); + } + } + + // not equal negated + if (literalValues.isTypeWithNaN()) { + testSimpleComparison(not(notEqual(columnExpression, integerPositive)), columnSymbol, Range.equal(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(notEqual(columnExpression, integerNegative)), columnSymbol, Range.equal(columnType, columnValues.getIntegerNegative())); + assertNoFullPushdown(not(notEqual(columnExpression, max))); + assertNoFullPushdown(not(notEqual(columnExpression, min))); + assertNoFullPushdown(not(notEqual(columnExpression, fractionalPositive))); + assertNoFullPushdown(not(notEqual(columnExpression, fractionalNegative))); + } + else { + testSimpleComparison(not(notEqual(columnExpression, integerPositive)), columnSymbol, Range.equal(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(notEqual(columnExpression, integerNegative)), columnSymbol, Range.equal(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(notEqual(columnExpression, max)), columnSymbol, Domain.none(columnType)); + testSimpleComparison(not(notEqual(columnExpression, min)), columnSymbol, Domain.none(columnType)); + if (literalValues.isFractional()) { + testSimpleComparison(not(notEqual(columnExpression, fractionalPositive)), columnSymbol, Domain.none(columnType)); + testSimpleComparison(not(notEqual(columnExpression, fractionalNegative)), columnSymbol, Domain.none(columnType)); + } } // is distinct from - testSimpleComparison(isDistinctFrom(columnExpression, integerPositive), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerPositive()), Range.greaterThan(columnType, columnValues.getIntegerPositive())), true)); - testSimpleComparison(isDistinctFrom(columnExpression, integerNegative), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerNegative()), Range.greaterThan(columnType, columnValues.getIntegerNegative())), true)); - testSimpleComparison(isDistinctFrom(columnExpression, max), columnSymbol, Domain.all(columnType)); - testSimpleComparison(isDistinctFrom(columnExpression, min), columnSymbol, Domain.all(columnType)); - if (literalValues.isFractional()) { + if (literalValues.isTypeWithNaN()) { + assertNoFullPushdown(isDistinctFrom(columnExpression, integerPositive)); + assertNoFullPushdown(isDistinctFrom(columnExpression, integerNegative)); + testSimpleComparison(isDistinctFrom(columnExpression, max), columnSymbol, Domain.all(columnType)); + testSimpleComparison(isDistinctFrom(columnExpression, min), columnSymbol, Domain.all(columnType)); testSimpleComparison(isDistinctFrom(columnExpression, fractionalPositive), columnSymbol, Domain.all(columnType)); testSimpleComparison(isDistinctFrom(columnExpression, fractionalNegative), columnSymbol, Domain.all(columnType)); } + else { + testSimpleComparison(isDistinctFrom(columnExpression, integerPositive), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerPositive()), Range.greaterThan(columnType, columnValues.getIntegerPositive())), true)); + testSimpleComparison(isDistinctFrom(columnExpression, integerNegative), columnSymbol, Domain.create(ValueSet.ofRanges(Range.lessThan(columnType, columnValues.getIntegerNegative()), Range.greaterThan(columnType, columnValues.getIntegerNegative())), true)); + testSimpleComparison(isDistinctFrom(columnExpression, max), columnSymbol, Domain.all(columnType)); + testSimpleComparison(isDistinctFrom(columnExpression, min), columnSymbol, Domain.all(columnType)); + if (literalValues.isFractional()) { + testSimpleComparison(isDistinctFrom(columnExpression, fractionalPositive), columnSymbol, Domain.all(columnType)); + testSimpleComparison(isDistinctFrom(columnExpression, fractionalNegative), columnSymbol, Domain.all(columnType)); + } + } + + // is distinct from negated + testSimpleComparison(not(isDistinctFrom(columnExpression, integerPositive)), columnSymbol, Range.equal(columnType, columnValues.getIntegerPositive())); + testSimpleComparison(not(isDistinctFrom(columnExpression, integerNegative)), columnSymbol, Range.equal(columnType, columnValues.getIntegerNegative())); + testSimpleComparison(not(isDistinctFrom(columnExpression, max)), columnSymbol, Domain.none(columnType)); + testSimpleComparison(not(isDistinctFrom(columnExpression, min)), columnSymbol, Domain.none(columnType)); + if (literalValues.isFractional()) { + testSimpleComparison(not(isDistinctFrom(columnExpression, fractionalPositive)), columnSymbol, Domain.none(columnType)); + testSimpleComparison(not(isDistinctFrom(columnExpression, fractionalNegative)), columnSymbol, Domain.none(columnType)); + } } @Test @@ -1206,6 +1416,12 @@ private void assertPredicateTranslates(Expression expression, TupleDomain 0); } + + public boolean isTypeWithNaN() + { + return type instanceof DoubleType || type instanceof RealType; + } } } diff --git a/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java b/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java index 5a2bff330344..2d9da8f3dc02 100644 --- a/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java +++ b/presto-spi/src/main/java/io/prestosql/spi/predicate/Marker.java @@ -17,11 +17,17 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorSession; +import io.prestosql.spi.type.DoubleType; +import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import java.util.Objects; import java.util.Optional; +import static io.prestosql.spi.predicate.Utils.blockToNativeValue; +import static io.prestosql.spi.predicate.Utils.nativeValueToBlock; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -66,6 +72,12 @@ public Marker( if (valueBlock.isPresent() && valueBlock.get().getPositionCount() != 1) { throw new IllegalArgumentException("value block should only have one position"); } + if (type instanceof RealType && valueBlock.isPresent() && Float.isNaN(intBitsToFloat(toIntExact((long) blockToNativeValue(type, valueBlock.get()))))) { + throw new IllegalArgumentException("cannot use Real NaN as range bound"); + } + if (type instanceof DoubleType && valueBlock.isPresent() && Double.isNaN((double) blockToNativeValue(type, valueBlock.get()))) { + throw new IllegalArgumentException("cannot use Double NaN as range bound"); + } this.type = type; this.valueBlock = valueBlock; this.bound = bound; @@ -73,7 +85,7 @@ public Marker( private static Marker create(Type type, Optional value, Bound bound) { - return new Marker(type, value.map(object -> Utils.nativeValueToBlock(type, object)), bound); + return new Marker(type, value.map(object -> nativeValueToBlock(type, object)), bound); } public static Marker upperUnbounded(Type type) @@ -126,7 +138,7 @@ public Object getValue() if (!valueBlock.isPresent()) { throw new IllegalStateException("No value to get"); } - return Utils.blockToNativeValue(type, valueBlock.get()); + return blockToNativeValue(type, valueBlock.get()); } public Object getPrintableValue(ConnectorSession session) diff --git a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestMarker.java b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestMarker.java index c14021165d1a..114dc52bbe2b 100644 --- a/presto-spi/src/test/java/io/prestosql/spi/predicate/TestMarker.java +++ b/presto-spi/src/test/java/io/prestosql/spi/predicate/TestMarker.java @@ -33,7 +33,10 @@ import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.RealType.REAL; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static java.lang.Float.floatToIntBits; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -162,6 +165,22 @@ public void testAdjacency() } } + @Test + public void testDoubleNaN() + { + assertThatThrownBy(() -> Marker.above(DOUBLE, Double.NaN)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> Marker.exactly(DOUBLE, Double.NaN)).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> Marker.below(DOUBLE, Double.NaN)).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void testRealNaN() + { + assertThatThrownBy(() -> Marker.above(REAL, (long) floatToIntBits(Float.NaN))).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> Marker.exactly(REAL, (long) floatToIntBits(Float.NaN))).isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> Marker.below(REAL, (long) floatToIntBits(Float.NaN))).isInstanceOf(IllegalArgumentException.class); + } + @Test public void testJsonSerialization() throws Exception