From 362787518c4ac0d9eacae71b4463ed82aa2da88a Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Fri, 27 Dec 2024 16:53:45 +0100 Subject: [PATCH 1/4] Adjust unwrapCast method to not be dependent on ComparisonExpression Co-authored-by: kabunchi --- .../rule/UnwrapCastInComparison.java | 92 +++++++++---------- 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index b5a230882dd0..3f21d79d6a87 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -159,36 +159,34 @@ public Visitor(PlannerContext plannerContext, Session session) public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter treeRewriter) { Comparison expression = treeRewriter.defaultRewrite(node, null); - return unwrapCast(expression); + return unwrapCast(expression.operator(), expression.left(), expression.right()).orElse(expression); } - private Expression unwrapCast(Comparison expression) + private Optional unwrapCast(Comparison.Operator operator, Expression leftExpression, Expression rightExpression) { // Canonicalization is handled by CanonicalizeExpressionRewriter - if (!(expression.left() instanceof Cast cast)) { - return expression; + if (!(leftExpression instanceof Cast cast)) { + return Optional.empty(); } - Expression right = optimizer.process(expression.right(), session, ImmutableMap.of()).orElse(expression.right()); - - Comparison.Operator operator = expression.operator(); + Expression right = optimizer.process(rightExpression, session, ImmutableMap.of()).orElse(rightExpression); if (right instanceof Constant constant && constant.value() == null) { return switch (operator) { - case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Constant(BOOLEAN, null); - case IDENTICAL -> new IsNull(cast); + case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> Optional.of(new Constant(BOOLEAN, null)); + case IDENTICAL -> Optional.of(new IsNull(cast)); }; } if (!(right instanceof Constant(Type type, Object rightValue))) { - return expression; + return Optional.empty(); } Type sourceType = cast.expression().type(); - Type targetType = expression.right().type(); + Type targetType = rightExpression.type(); if (sourceType instanceof TimestampType && targetType == DATE) { - return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.expression(), (long) rightValue).orElse(expression); + return unwrapTimestampToDateCast((TimestampType) sourceType, operator, cast.expression(), (long) rightValue); } if (targetType instanceof TimestampWithTimeZoneType) { @@ -197,7 +195,7 @@ private Expression unwrapCast(Comparison expression) } if (!hasInjectiveImplicitCoercion(sourceType, targetType, rightValue)) { - return expression; + return Optional.empty(); } // Handle comparison against NaN. @@ -209,12 +207,12 @@ private Expression unwrapCast(Comparison expression) case GREATER_THAN_OR_EQUAL: case LESS_THAN: case LESS_THAN_OR_EQUAL: - return falseIfNotNull(cast.expression()); + return Optional.of(falseIfNotNull(cast.expression())); case NOT_EQUAL: - return trueIfNotNull(cast.expression()); + return Optional.of(trueIfNotNull(cast.expression())); case IDENTICAL: if (!typeHasNaN(sourceType)) { - return FALSE; + return Optional.of(FALSE); } // NaN on the right of comparison will be cast to source type later break; @@ -242,21 +240,21 @@ private Expression unwrapCast(Comparison expression) if (upperBoundComparison > 0) { // larger than maximum representable value return switch (operator) { - case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> falseIfNotNull(cast.expression()); - case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); - case IDENTICAL -> FALSE; + case EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> Optional.of(falseIfNotNull(cast.expression())); + case NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case IDENTICAL -> Optional.of(FALSE); }; } if (upperBoundComparison == 0) { // equal to max representable value return switch (operator) { - case GREATER_THAN -> falseIfNotNull(cast.expression()); - case GREATER_THAN_OR_EQUAL -> new Comparison(EQUAL, cast.expression(), new Constant(sourceType, max)); - case LESS_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); - case LESS_THAN -> new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, max)); + case GREATER_THAN -> Optional.of(falseIfNotNull(cast.expression())); + case GREATER_THAN_OR_EQUAL -> Optional.of(new Comparison(EQUAL, cast.expression(), new Constant(sourceType, max))); + case LESS_THAN_OR_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case LESS_THAN -> Optional.of(new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, max))); case EQUAL, NOT_EQUAL, IDENTICAL -> - new Comparison(operator, cast.expression(), new Constant(sourceType, max)); + Optional.of(new Comparison(operator, cast.expression(), new Constant(sourceType, max))); }; } @@ -267,21 +265,21 @@ private Expression unwrapCast(Comparison expression) if (lowerBoundComparison < 0) { // smaller than minimum representable value return switch (operator) { - case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); - case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> falseIfNotNull(cast.expression()); - case IDENTICAL -> FALSE; + case NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> Optional.of(falseIfNotNull(cast.expression())); + case IDENTICAL -> Optional.of(FALSE); }; } if (lowerBoundComparison == 0) { // equal to min representable value return switch (operator) { - case LESS_THAN -> falseIfNotNull(cast.expression()); - case LESS_THAN_OR_EQUAL -> new Comparison(EQUAL, cast.expression(), new Constant(sourceType, min)); - case GREATER_THAN_OR_EQUAL -> trueIfNotNull(cast.expression()); - case GREATER_THAN -> new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, min)); + case LESS_THAN -> Optional.of(falseIfNotNull(cast.expression())); + case LESS_THAN_OR_EQUAL -> Optional.of(new Comparison(EQUAL, cast.expression(), new Constant(sourceType, min))); + case GREATER_THAN_OR_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case GREATER_THAN -> Optional.of(new Comparison(NOT_EQUAL, cast.expression(), new Constant(sourceType, min))); case EQUAL, NOT_EQUAL, IDENTICAL -> - new Comparison(operator, cast.expression(), new Constant(sourceType, min)); + Optional.of(new Comparison(operator, cast.expression(), new Constant(sourceType, min))); }; } } @@ -293,7 +291,7 @@ private Expression unwrapCast(Comparison expression) } catch (OperatorNotFoundException e) { // Without a cast between target -> source, there's nothing more we can do - return expression; + return Optional.empty(); } Object literalInSourceType; @@ -307,7 +305,7 @@ private Expression unwrapCast(Comparison expression) // 3. out of range or otherwise unrepresentable value // Since we can't distinguish between those cases, take the conservative option // and bail out. - return expression; + return Optional.empty(); } if (targetType.isOrderable()) { @@ -318,40 +316,40 @@ private Expression unwrapCast(Comparison expression) if (literalVsRoundtripped > 0) { // cast rounded down return switch (operator) { - case EQUAL -> falseIfNotNull(cast.expression()); - case NOT_EQUAL -> trueIfNotNull(cast.expression()); - case IDENTICAL -> FALSE; + case EQUAL -> Optional.of(falseIfNotNull(cast.expression())); + case NOT_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case IDENTICAL -> Optional.of(FALSE); case LESS_THAN, LESS_THAN_OR_EQUAL -> { if (sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMin(), literalInSourceType) == 0) { - yield new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); + yield Optional.of(new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType))); } - yield new Comparison(LESS_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); + yield Optional.of(new Comparison(LESS_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType))); } case GREATER_THAN, GREATER_THAN_OR_EQUAL -> // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // larger than the next value in the source type - new Comparison(GREATER_THAN, cast.expression(), new Constant(sourceType, literalInSourceType)); + Optional.of(new Comparison(GREATER_THAN, cast.expression(), new Constant(sourceType, literalInSourceType))); }; } if (literalVsRoundtripped < 0) { // cast rounded up return switch (operator) { - case EQUAL -> falseIfNotNull(cast.expression()); - case NOT_EQUAL -> trueIfNotNull(cast.expression()); - case IDENTICAL -> FALSE; + case EQUAL -> Optional.of(falseIfNotNull(cast.expression())); + case NOT_EQUAL -> Optional.of(trueIfNotNull(cast.expression())); + case IDENTICAL -> Optional.of(FALSE); case LESS_THAN, LESS_THAN_OR_EQUAL -> // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot produce a value // smaller than the next value in the source type - new Comparison(LESS_THAN, cast.expression(), new Constant(sourceType, literalInSourceType)); + Optional.of(new Comparison(LESS_THAN, cast.expression(), new Constant(sourceType, literalInSourceType))); case GREATER_THAN, GREATER_THAN_OR_EQUAL -> sourceRange.isPresent() && compare(sourceType, sourceRange.get().getMax(), literalInSourceType) == 0 ? - new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)) : - new Comparison(GREATER_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)); + Optional.of(new Comparison(EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType))) : + Optional.of(new Comparison(GREATER_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType))); }; } } - return new Comparison(operator, cast.expression(), new Constant(sourceType, literalInSourceType)); + return Optional.of(new Comparison(operator, cast.expression(), new Constant(sourceType, literalInSourceType))); } private Optional unwrapTimestampToDateCast(TimestampType sourceType, Comparison.Operator operator, Expression timestampExpression, long date) From 61c4581bc2408b34a2d50afe8301153aac7713b5 Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Fri, 27 Dec 2024 17:30:23 +0100 Subject: [PATCH 2/4] Unwrap casts in BETWEEN predicate This change allows the engine to infer that, for instance, given t::timestamp(6) cast(t as date) BETWEEN DATE '2022-01-01' AND DATE '2022-01-02' can be rewritten as t BETWEEN TIMESTAMP '2022-01-01 00:00:00' AND TIMESTAMP '2022-01-02 23:59:59.999999' Range predicate `BetweenPredicate` can be transformed into a `TupleDomain` and thus help with predicate pushdown. Range-based `TupleDomain` representation is critical for connectors which have min/max-based metadata (like Iceberg manifests lists which play a key role in partition pruning or Iceberg data files), as ranges allow for intersection tests, something that is hard to do in a generic manner for `ConnectorExpression`. --- .../rule/UnwrapCastInComparison.java | 123 ++++++++++++++++ .../planner/TestUnwrapCastInComparison.java | 138 ++++++++++++++++++ .../sql/query/TestUnwrapCastInComparison.java | 88 +++++++++++ .../iceberg/BaseIcebergConnectorTest.java | 47 ++++++ 4 files changed, 396 insertions(+) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java index 3f21d79d6a87..d632484f1953 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java @@ -36,6 +36,7 @@ import io.trino.spi.type.VarcharType; import io.trino.sql.InterpretedFunctionInvoker; import io.trino.sql.PlannerContext; +import io.trino.sql.ir.Between; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -50,7 +51,9 @@ import java.time.ZoneId; import java.time.temporal.ChronoUnit; import java.time.zone.ZoneOffsetTransition; +import java.util.Objects; import java.util.Optional; +import java.util.Set; import static com.google.common.base.Verify.verify; import static io.airlift.slice.SliceUtf8.countCodePoints; @@ -79,6 +82,7 @@ import static io.trino.sql.ir.IrUtils.and; import static io.trino.sql.ir.IrUtils.or; import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer; +import static io.trino.type.UnknownType.UNKNOWN; import static java.lang.Float.intBitsToFloat; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -162,6 +166,125 @@ public Expression rewriteComparison(Comparison node, Void context, ExpressionTre return unwrapCast(expression.operator(), expression.left(), expression.right()).orElse(expression); } + @Override + public Expression rewriteBetween(Between node, Void context, ExpressionTreeRewriter treeRewriter) + { + Between expression = treeRewriter.defaultRewrite(node, null); + + if (!(expression.value() instanceof Cast cast)) { + return expression; + } + + Optional optionalLowBoundUnwrapped = unwrapCast(GREATER_THAN_OR_EQUAL, node.value(), node.min()); + Optional optionalHighBoundUnwrapped = unwrapCast(LESS_THAN_OR_EQUAL, node.value(), node.max()); + if (optionalLowBoundUnwrapped.isEmpty() || optionalHighBoundUnwrapped.isEmpty()) { + return expression; + } + Expression lowBoundUnwrapped = optionalLowBoundUnwrapped.get(); + Expression highBoundUnwrapped = optionalHighBoundUnwrapped.get(); + + Expression trueIfNotNullExpression = trueIfNotNull(cast.expression()); + if (trueIfNotNullExpression.equals(lowBoundUnwrapped)) { + return highBoundUnwrapped; + } + if (trueIfNotNullExpression.equals(highBoundUnwrapped)) { + return lowBoundUnwrapped; + } + + Expression falseIfNotNullExpression = falseIfNotNull(cast.expression()); + if (falseIfNotNullExpression.equals(lowBoundUnwrapped) || falseIfNotNullExpression.equals(highBoundUnwrapped)) { + if (falseIfNotNullExpression.equals(lowBoundUnwrapped) && falseIfNotNullExpression.equals(highBoundUnwrapped)) { + return falseIfNotNullExpression; + } + return and(lowBoundUnwrapped, highBoundUnwrapped); + } + Expression castNullToBoolean = new Cast(new Constant(UNKNOWN, null), BOOLEAN); + if (castNullToBoolean.equals(lowBoundUnwrapped) || castNullToBoolean.equals(highBoundUnwrapped)) { + if (castNullToBoolean.equals(lowBoundUnwrapped) && castNullToBoolean.equals(highBoundUnwrapped)) { + return castNullToBoolean; + } + return and(lowBoundUnwrapped, highBoundUnwrapped); + } + + if (lowBoundUnwrapped instanceof Comparison lowBoundUnwrappedComparison && + highBoundUnwrapped instanceof Comparison highBoundUnwrappedComparison) { + Type sourceType = cast.expression().type(); + Type lowBoundType = lowBoundUnwrappedComparison.right().type(); + Type highBoundType = highBoundUnwrappedComparison.right().type(); + + if (cast.expression().equals(lowBoundUnwrappedComparison.left()) && + Objects.equals(sourceType, lowBoundType) && + cast.expression().equals(highBoundUnwrappedComparison.left()) && + Objects.equals(sourceType, highBoundType)) { + if (Set.of(GREATER_THAN, GREATER_THAN_OR_EQUAL).contains(lowBoundUnwrappedComparison.operator()) && + Set.of(LESS_THAN, LESS_THAN_OR_EQUAL).contains(highBoundUnwrappedComparison.operator())) { + // Try to reconstruct the BETWEEN predicate with the cast unwrapped + if (!(lowBoundUnwrappedComparison.right() instanceof Constant(Type _, Object lowBoundValue)) + || !(highBoundUnwrappedComparison.right() instanceof Constant(Type _, Object highBoundValue))) { + return expression; + } + + int compareLowBoundValueAndHighBoundValue = compare(sourceType, lowBoundValue, highBoundValue); + if (compareLowBoundValueAndHighBoundValue > 0) { + // range min greater than range max + return falseIfNotNull(cast.expression()); + } + + Expression greaterThanOrEqualLowBoundUnwrappedExpression; + Optional nextAfterLowBoundValue = Optional.empty(); + if (lowBoundUnwrappedComparison.operator() == GREATER_THAN) { + nextAfterLowBoundValue = sourceType.getNextValue(lowBoundValue); + if (nextAfterLowBoundValue.isEmpty()) { + return and(lowBoundUnwrappedComparison, highBoundUnwrappedComparison); + } + greaterThanOrEqualLowBoundUnwrappedExpression = new Constant(sourceType, nextAfterLowBoundValue.get()); + } + else { + greaterThanOrEqualLowBoundUnwrappedExpression = lowBoundUnwrappedComparison.right(); + } + + Expression lessThanOrEqualHighBoundUnwrappedExpression; + Optional previousBeforeHighBoundValue = Optional.empty(); + if (highBoundUnwrappedComparison.operator() == LESS_THAN) { + previousBeforeHighBoundValue = sourceType.getPreviousValue(highBoundValue); + if (previousBeforeHighBoundValue.isEmpty()) { + return and(lowBoundUnwrappedComparison, highBoundUnwrappedComparison); + } + lessThanOrEqualHighBoundUnwrappedExpression = new Constant(sourceType, previousBeforeHighBoundValue.get()); + } + else { + lessThanOrEqualHighBoundUnwrappedExpression = highBoundUnwrappedComparison.right(); + } + + if (nextAfterLowBoundValue.isPresent() && previousBeforeHighBoundValue.isPresent()) { + int compareNextLowAndPreviousHighBound = compare(sourceType, nextAfterLowBoundValue.get(), previousBeforeHighBoundValue.get()); + if (compareNextLowAndPreviousHighBound >= 0) { + return falseIfNotNull(cast.expression()); + } + } + else if (previousBeforeHighBoundValue.isPresent()) { + int compareLowAndPreviousHighBound = compare(sourceType, lowBoundValue, previousBeforeHighBoundValue.get()); + if (compareLowAndPreviousHighBound > 0) { + return falseIfNotNull(cast.expression()); + } + } + else if (nextAfterLowBoundValue.isPresent()) { + int compareNextLowAndHighBound = compare(sourceType, nextAfterLowBoundValue.get(), highBoundValue); + if (compareNextLowAndHighBound > 0) { + return falseIfNotNull(cast.expression()); + } + } + + return new Between(cast.expression(), greaterThanOrEqualLowBoundUnwrappedExpression, lessThanOrEqualHighBoundUnwrappedExpression); + } + } + + return and(lowBoundUnwrappedComparison, highBoundUnwrappedComparison); + } + + return expression; + } + private Optional unwrapCast(Comparison.Operator operator, Expression leftExpression, Expression rightExpression) { // Canonicalization is handled by CanonicalizeExpressionRewriter diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java index 6da20c634beb..7d21ab82d251 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestUnwrapCastInComparison.java @@ -18,8 +18,11 @@ import io.trino.Session; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.type.Decimals; import io.trino.spi.type.TimeZoneKey; +import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; +import io.trino.sql.ir.Case; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; @@ -28,11 +31,14 @@ import io.trino.sql.ir.IsNull; import io.trino.sql.ir.Logical; import io.trino.sql.ir.Reference; +import io.trino.sql.ir.WhenClause; import io.trino.sql.planner.assertions.BasePlanTest; import io.trino.type.DateTimes; import io.trino.util.DateTimeUtils; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; + import static io.trino.SystemSessionProperties.PUSH_FILTER_INTO_VALUES_MAX_ROW_COUNT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -344,6 +350,116 @@ public void testGreaterThanOrEqual() testUnwrap("bigint", "a >= DOUBLE '-18446744073709551616'", new Logical(OR, ImmutableList.of(not(new IsNull(new Reference(BIGINT, "a"))), new Constant(BOOLEAN, null)))); } + @Test + public void testBetween() + { + // representable + testUnwrap("smallint", "a BETWEEN DOUBLE '1' AND DOUBLE '2'", new Between(new Reference(SMALLINT, "a"), new Constant(SMALLINT, 1L), new Constant(SMALLINT, 2L))); + testUnwrap("bigint", "a BETWEEN DOUBLE '1' AND DOUBLE '2'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 1L), new Constant(BIGINT, 2L))); + testUnwrap("decimal(7, 2)", "a BETWEEN DOUBLE '1.23' AND DOUBLE '4.56'", new Between(new Reference(createDecimalType(7, 2), "a"), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("1.23"))), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("4.56"))))); + // cast down is possible + testUnwrap("decimal(7, 2)", "CAST(a AS DECIMAL(12,2)) BETWEEN CAST(DECIMAL '111.00' AS decimal(12,2)) AND CAST(DECIMAL '222.0' AS decimal(12,2))", new Between(new Reference(createDecimalType(7, 2), "a"), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("111.00"))), new Constant(createDecimalType(7, 2), Decimals.valueOfShort(new BigDecimal("222.00"))))); + + // non-representable, min cast round up, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '1.1' AND DOUBLE '2.2'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 2L))); + // non-representable, min cast round up, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '1.1' AND DOUBLE '1.1'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + // non-representable, min cast round up, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '1.1' AND DOUBLE '1.9'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + // non-representable, min cast round up, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '1.9' AND DOUBLE '1.9'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(BIGINT, "a")), new Constant(BOOLEAN, null)))); + // non-representable, min cast round down, max cast no rounding + testUnwrap("bigint", "a BETWEEN DOUBLE '1.1' AND DOUBLE '2'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 2L))); + // non-representable, min cast round up, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '1.9' AND DOUBLE '2.2'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 2L))); + // non-representable, min cast round up, max cast round up + testUnwrap("bigint", "a BETWEEN DOUBLE '1.9' AND DOUBLE '2.9'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 2L))); + // non-representable, min cast round up, max cast no rounding + testUnwrap("bigint", "a BETWEEN DOUBLE '1.9' AND DOUBLE '3'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L))); + // non-representable, min cast no rounding, max cast round down + testUnwrap("bigint", "a BETWEEN DOUBLE '2' AND DOUBLE '3.2'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L))); + // non-representable, min cast no rounding, max cast round up + testUnwrap("bigint", "a BETWEEN DOUBLE '2' AND DOUBLE '2.9'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 2L))); + // non-representable, min cast no rounding, max cast no rounding + testUnwrap("bigint", "a BETWEEN DOUBLE '2' AND DOUBLE '3'", new Between(new Reference(BIGINT, "a"), new Constant(BIGINT, 2L), new Constant(BIGINT, 3L))); + + // cast down not possible + testUnwrap( + "decimal(7, 2)", + "CAST(a AS DECIMAL(12,2)) BETWEEN CAST(DECIMAL '1111111111.00' AS decimal(12,2)) AND CAST(DECIMAL '2222222222.0' AS decimal(12,2))", + new Logical(AND, ImmutableList.of(new IsNull(new Reference(createDecimalType(7, 2), "a")), new Constant(BOOLEAN, null)))); + + // illegal range + testUnwrap( + "smallint", + "a BETWEEN DOUBLE '5' AND DOUBLE '4'", + new Case(ImmutableList.of( + new WhenClause(not(new IsNull(new Cast(new Reference(SMALLINT, "a"), DOUBLE))), new Constant(BOOLEAN, false))), + new Constant(BOOLEAN, null))); + + // NULL + testUnwrap( + "smallint", + "a BETWEEN NULL AND DOUBLE '2'", + new Case(ImmutableList.of( + new WhenClause(new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L)), new Constant(BOOLEAN, false))), + new Constant(BOOLEAN, null))); + testUnwrap( + "smallint", + "a BETWEEN DOUBLE '2' AND NULL", + new Case(ImmutableList.of( + new WhenClause(new Comparison(LESS_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 2L)), new Constant(BOOLEAN, false))), + new Constant(BOOLEAN, null))); + + // nan + testUnwrap( + "smallint", + "a BETWEEN nan() AND DOUBLE '2'", + new Case(ImmutableList.of( + new WhenClause(not(new IsNull(new Cast(new Reference(SMALLINT, "a"), DOUBLE))), new Constant(BOOLEAN, false))), + new Constant(BOOLEAN, null))); + testUnwrap( + "smallint", + "a BETWEEN DOUBLE '2' AND nan()", + new Case(ImmutableList.of( + new WhenClause(not(new IsNull(new Cast(new Reference(SMALLINT, "a"), DOUBLE))), new Constant(BOOLEAN, false))), + new Constant(BOOLEAN, null))); + + // min and max below bottom of range + testUnwrap("smallint", "a BETWEEN DOUBLE '-50000' AND DOUBLE '-40000'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + // min below bottom of range, max at the bottom of range + testUnwrap("smallint", "a BETWEEN DOUBLE '-32768.1' AND DOUBLE '-32768'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + testUnwrap("smallint", "a BETWEEN DOUBLE '-32768.1' AND DOUBLE '-32767.9'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32768L))); + // min below bottom of range, max within range + testUnwrap("smallint", "a BETWEEN DOUBLE '-32768.1' AND DOUBLE '0'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 0L))); + // min at the bottom of range, max within range + testUnwrap("smallint", "a BETWEEN DOUBLE '-32768' AND DOUBLE '0'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 0L))); + // min round to bottom of range, max within range + testUnwrap("smallint", "a BETWEEN DOUBLE '-32767.9' AND DOUBLE '0'", new Between(new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L), new Constant(SMALLINT, 0L))); + // min above bottom of range, max within range + testUnwrap("smallint", "a BETWEEN DOUBLE '-32767' AND DOUBLE '0'", new Between(new Reference(SMALLINT, "a"), new Constant(SMALLINT, -32767L), new Constant(SMALLINT, 0L))); + // min & max below within of range + testUnwrap("smallint", "a BETWEEN DOUBLE '32765' AND DOUBLE '32766'", new Between(new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32765L), new Constant(SMALLINT, 32766L))); + // min within and max round to top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '32765.9' AND DOUBLE '32766.9'", new Between(new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32766L), new Constant(SMALLINT, 32766L))); + // min below and max at the top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '32760' AND DOUBLE '32767'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32760L))); + // min below and max above top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '32760.1' AND DOUBLE '32768.1'", new Comparison(GREATER_THAN, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32760L))); + // min at the top of range and max above top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '32767' AND DOUBLE '32768.1'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + testUnwrap("smallint", "a BETWEEN DOUBLE '32766.9' AND DOUBLE '32768.1'", new Comparison(EQUAL, new Reference(SMALLINT, "a"), new Constant(SMALLINT, 32767L))); + // min and max above top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '40000' AND DOUBLE '50000'", new Logical(AND, ImmutableList.of(new IsNull(new Reference(SMALLINT, "a")), new Constant(BOOLEAN, null)))); + // min below range and max at the top of range + testUnwrap("smallint", "a BETWEEN DOUBLE '-40000' AND DOUBLE '32767'", new Logical(OR, ImmutableList.of(not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + // min below range and max above range + testUnwrap("smallint", "a BETWEEN DOUBLE '-40000' AND DOUBLE '40000'", new Logical(OR, ImmutableList.of(not(new IsNull(new Reference(SMALLINT, "a"))), new Constant(BOOLEAN, null)))); + + // -2^64 constant + testUnwrap("bigint", "a BETWEEN DOUBLE '-18446744073709551616' AND DOUBLE '0'", new Comparison(LESS_THAN_OR_EQUAL, new Reference(BIGINT, "a"), new Constant(BIGINT, 0L))); + } + @Test public void testDistinctFrom() { @@ -591,6 +707,7 @@ public void testCastTimestampToTimestampWithTimeZone() // long timestamp, long timestamp with time zone testUnwrap(warsawSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.123456000")))); + testUnwrap(warsawSession, "timestamp(9)", "a BETWEEN TIMESTAMP '2020-10-26 11:02:18.123456 UTC' AND TIMESTAMP '2020-10-26 12:03:20.345678 UTC'", new Between(new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 12:02:18.123456000")), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 13:03:20.345678000")))); testUnwrap(losAngelesSession, "timestamp(9)", "a > TIMESTAMP '2020-10-26 11:02:18.123456 UTC'", new Comparison(GREATER_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "2020-10-26 04:02:18.123456000")))); // maximum precision @@ -726,6 +843,12 @@ public void testUnwrapCastTimestampAsDate() testUnwrap("timestamp(9)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); testUnwrap("timestamp(12)", "CAST(a AS DATE) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + // between + testUnwrap("timestamp(3)", "CAST(a AS DATE) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-07-23 23:59:59.999")))); + testUnwrap("timestamp(6)", "CAST(a AS DATE) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-07-23 23:59:59.999999")))); + testUnwrap("timestamp(9)", "CAST(a AS DATE) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-07-24 00:00:00.000000000")))))); + testUnwrap("timestamp(12)", "CAST(a AS DATE) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-07-24 00:00:00.000000000000")))))); + // is distinct testUnwrap("timestamp(3)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); testUnwrap("timestamp(6)", "CAST(a AS DATE) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(6), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); @@ -792,6 +915,21 @@ public void testUnwrapConvertTimestampToDate() testUnwrap("timestamp(9)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000")))); testUnwrap("timestamp(12)", "date(a) >= DATE '1981-06-22'", new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000")))); + // between + testUnwrap("timestamp(0)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(0), "a"), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "1981-06-22 00:00:00")), new Constant(createTimestampType(0), DateTimes.parseTimestamp(0, "1981-07-23 23:59:59")))); + testUnwrap("timestamp(1)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(1), "a"), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "1981-06-22 00:00:00.0")), new Constant(createTimestampType(1), DateTimes.parseTimestamp(1, "1981-07-23 23:59:59.9")))); + testUnwrap("timestamp(2)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(2), "a"), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "1981-06-22 00:00:00.00")), new Constant(createTimestampType(2), DateTimes.parseTimestamp(2, "1981-07-23 23:59:59.99")))); + testUnwrap("timestamp(3)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000")), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-07-23 23:59:59.999")))); + testUnwrap("timestamp(4)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(4), "a"), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "1981-06-22 00:00:00.0000")), new Constant(createTimestampType(4), DateTimes.parseTimestamp(4, "1981-07-23 23:59:59.9999")))); + testUnwrap("timestamp(5)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(5), "a"), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "1981-06-22 00:00:00.00000")), new Constant(createTimestampType(5), DateTimes.parseTimestamp(5, "1981-07-23 23:59:59.99999")))); + testUnwrap("timestamp(6)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Between(new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000")), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-07-23 23:59:59.999999")))); + testUnwrap("timestamp(7)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "1981-06-22 00:00:00.0000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(7), "a"), new Constant(createTimestampType(7), DateTimes.parseTimestamp(7, "1981-07-24 00:00:00.0000000")))))); + testUnwrap("timestamp(8)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "1981-06-22 00:00:00.00000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(8), "a"), new Constant(createTimestampType(8), DateTimes.parseTimestamp(8, "1981-07-24 00:00:00.00000000")))))); + testUnwrap("timestamp(9)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-06-22 00:00:00.000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(9), "a"), new Constant(createTimestampType(9), DateTimes.parseTimestamp(9, "1981-07-24 00:00:00.000000000")))))); + testUnwrap("timestamp(10)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "1981-06-22 00:00:00.0000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(10), "a"), new Constant(createTimestampType(10), DateTimes.parseTimestamp(10, "1981-07-24 00:00:00.0000000000")))))); + testUnwrap("timestamp(11)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "1981-06-22 00:00:00.00000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(11), "a"), new Constant(createTimestampType(11), DateTimes.parseTimestamp(11, "1981-07-24 00:00:00.00000000000")))))); + testUnwrap("timestamp(12)", "date(a) BETWEEN DATE '1981-06-22' AND DATE '1981-07-23'", new Logical(AND, ImmutableList.of(new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-06-22 00:00:00.000000000000"))), new Comparison(LESS_THAN, new Reference(createTimestampType(12), "a"), new Constant(createTimestampType(12), DateTimes.parseTimestamp(12, "1981-07-24 00:00:00.000000000000")))))); + // is distinct testUnwrap("timestamp(3)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(3), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-22 00:00:00.000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(3), "a"), new Constant(createTimestampType(3), DateTimes.parseTimestamp(3, "1981-06-23 00:00:00.000")))))); testUnwrap("timestamp(6)", "date(a) IS DISTINCT FROM DATE '1981-06-22'", new Logical(OR, ImmutableList.of(new IsNull(new Reference(createTimestampType(6), "a")), new Comparison(LESS_THAN, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-22 00:00:00.000000"))), new Comparison(GREATER_THAN_OR_EQUAL, new Reference(createTimestampType(6), "a"), new Constant(createTimestampType(6), DateTimes.parseTimestamp(6, "1981-06-23 00:00:00.000000")))))); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestUnwrapCastInComparison.java b/core/trino-main/src/test/java/io/trino/sql/query/TestUnwrapCastInComparison.java index b7e3f6db96dd..4c95f56356b6 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestUnwrapCastInComparison.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestUnwrapCastInComparison.java @@ -75,6 +75,14 @@ public void testTinyint() validate(operator, fromType, from, "DOUBLE", to); } } + + for (Number to : asList(null, Byte.MIN_VALUE - 1, Byte.MIN_VALUE, 0, 1, Byte.MAX_VALUE, Byte.MAX_VALUE + 1)) { + validateBetween(fromType, from, "SMALLINT", to, to); + validateBetween(fromType, from, "INTEGER", to, to); + validateBetween(fromType, from, "BIGINT", to, to); + validateBetween(fromType, from, "REAL", to, to); + validateBetween(fromType, from, "DOUBLE", to, to); + } } } @@ -100,6 +108,13 @@ public void testSmallint() validate(operator, fromType, from, "DOUBLE", to); } } + + for (Number to : asList(null, Short.MIN_VALUE - 1, Short.MIN_VALUE, 0, 1, Short.MAX_VALUE, Short.MAX_VALUE + 1)) { + validateBetween(fromType, from, "INTEGER", to, to); + validateBetween(fromType, from, "BIGINT", to, to); + validateBetween(fromType, from, "REAL", to, to); + validateBetween(fromType, from, "DOUBLE", to, to); + } } } @@ -121,6 +136,16 @@ public void testInteger() validate(operator, fromType, from, "REAL", to); } } + + for (Number to : asList(null, Integer.MIN_VALUE - 1L, Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE, Integer.MAX_VALUE + 1L)) { + validateBetween(fromType, from, "BIGINT", to, to); + } + for (Number to : asList(null, Integer.MIN_VALUE - 1L, Integer.MIN_VALUE, 0, 0.1, 0.9, 1, Integer.MAX_VALUE, Integer.MAX_VALUE + 1L)) { + validateBetween(fromType, from, "DOUBLE", to, to); + } + for (Number to : asList(null, Integer.MIN_VALUE - 1L, Integer.MIN_VALUE, -1L << 23 + 1, 0, 0.1, 0.9, 1, 1L << 23 - 1, Integer.MAX_VALUE, Integer.MAX_VALUE + 1L)) { + validateBetween(fromType, from, "REAL", to, to); + } } } @@ -138,6 +163,13 @@ public void testBigint() validate(operator, fromType, from, "REAL", to); } } + + for (Number to : asList(null, Long.MIN_VALUE, Long.MIN_VALUE + 1, -1L << 53 + 1, 0, 0.1, 0.9, 1, 1L << 53 - 1, Long.MAX_VALUE - 1, Long.MAX_VALUE)) { + validateBetween(fromType, from, "DOUBLE", to, to); + } + for (Number to : asList(null, Long.MIN_VALUE, Long.MIN_VALUE + 1, -1L << 23 + 1, 0, 0.1, 0.9, 1, 1L << 23 - 1, Long.MAX_VALUE - 1, Long.MAX_VALUE)) { + validateBetween(fromType, from, "REAL", to, to); + } } } @@ -153,6 +185,9 @@ public void testReal() validate(operator, fromType, from, toType, to); } } + for (String to : toLiteral(toType, asList(null, Double.NEGATIVE_INFINITY, Math.nextDown((double) -Float.MIN_VALUE), (double) -Float.MIN_VALUE, 0, 0.1, 0.9, 1, (double) Float.MAX_VALUE, Math.nextUp((double) Float.MAX_VALUE), Double.POSITIVE_INFINITY, Double.NaN))) { + validateBetween(fromType, from, toType, to, to); + } } } @@ -167,6 +202,9 @@ public void testDecimal() validate(operator, "DECIMAL(15, 0)", from, "DOUBLE", Double.valueOf(to)); } } + for (String to : values) { + validateBetween("DECIMAL(15, 0)", from, "DOUBLE", Double.valueOf(to), Double.valueOf(to)); + } } // decimal(16) -> double @@ -177,6 +215,9 @@ public void testDecimal() validate(operator, "DECIMAL(16, 0)", from, "DOUBLE", Double.valueOf(to)); } } + for (String to : values) { + validateBetween("DECIMAL(16, 0)", from, "DOUBLE", Double.valueOf(to), Double.valueOf(to)); + } } // decimal(7) -> real @@ -187,6 +228,9 @@ public void testDecimal() validate(operator, "DECIMAL(7, 0)", from, "REAL", Double.valueOf(to)); } } + for (String to : values) { + validateBetween("DECIMAL(7, 0)", from, "REAL", Double.valueOf(to), Double.valueOf(to)); + } } // decimal(8) -> real @@ -197,6 +241,9 @@ public void testDecimal() validate(operator, "DECIMAL(8, 0)", from, "REAL", Double.valueOf(to)); } } + for (String to : values) { + validateBetween("DECIMAL(8, 0)", from, "REAL", Double.valueOf(to), Double.valueOf(to)); + } } } @@ -209,6 +256,9 @@ public void testVarchar() validate(operator, "VARCHAR(1)", from, "VARCHAR(2)", to); } } + for (String to : asList(null, "''", "'a'", "'aa'", "'b'", "'bb'")) { + validateBetween("VARCHAR(1)", from, "VARCHAR(2)", to, to); + } } // type with no range @@ -217,6 +267,9 @@ public void testVarchar() validate(operator, "VARCHAR(200)", "'" + "a".repeat(200) + "'", "VARCHAR(300)", to); } } + for (String to : asList("'" + "a".repeat(200) + "'", "'" + "b".repeat(200) + "'")) { + validateBetween("VARCHAR(200)", "'" + "a".repeat(200) + "'", "VARCHAR(300)", to, to); + } } @Test @@ -327,6 +380,15 @@ public void testCastTimestampToTimestampWithTimeZone() validate(session, operator, "timestamp(12)", "TIMESTAMP '2020-07-03 01:23:45.123456789123'", "timestamp(12) with time zone", "TIMESTAMP '2020-07-03 01:23:45 UTC'"); } + validateBetween(session, "timestamp(3)", "TIMESTAMP '2020-07-03 01:23:45.123'", "timestamp(3) with time zone", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'"); + validateBetween(session, "timestamp(3)", "TIMESTAMP '2020-07-03 01:23:45.123'", "timestamp(3) with time zone", "TIMESTAMP '2020-07-03 01:23:45 UTC'", "TIMESTAMP '2020-07-03 01:23:45 UTC'"); + validateBetween(session, "timestamp(6)", "TIMESTAMP '2020-07-03 01:23:45.123456'", "timestamp(6) with time zone", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'"); + validateBetween(session, "timestamp(6)", "TIMESTAMP '2020-07-03 01:23:45.123456'", "timestamp(6) with time zone", "TIMESTAMP '2020-07-03 01:23:45 UTC'", "TIMESTAMP '2020-07-03 01:23:45 UTC'"); + validateBetween(session, "timestamp(9)", "TIMESTAMP '2020-07-03 01:23:45.123456789'", "timestamp(9) with time zone", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'"); + validateBetween(session, "timestamp(9)", "TIMESTAMP '2020-07-03 01:23:45.123456789'", "timestamp(9) with time zone", "TIMESTAMP '2020-07-03 01:23:45 UTC'", "TIMESTAMP '2020-07-03 01:23:45 UTC'"); + validateBetween(session, "timestamp(12)", "TIMESTAMP '2020-07-03 01:23:45.123456789123'", "timestamp(12) with time zone", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'", "TIMESTAMP '2020-07-03 01:23:45 Europe/Warsaw'"); + validateBetween(session, "timestamp(12)", "TIMESTAMP '2020-07-03 01:23:45.123456789123'", "timestamp(12) with time zone", "TIMESTAMP '2020-07-03 01:23:45 UTC'", "TIMESTAMP '2020-07-03 01:23:45 UTC'"); + // DST forward change (2017-09-24 03:00 -> 2017-09-24 04:00) List fromLocalTimes = asList( LocalTime.parse("02:59:59.999999999"), @@ -459,6 +521,32 @@ private void validate(Session session, String operator, String fromType, Object .isTrue(); } + private void validateBetween(String fromType, Object fromValue, String toType, Object minValue, Object maxValue) + { + validateBetween(assertions.getDefaultSession(), fromType, fromValue, toType, minValue, maxValue); + } + + private void validateBetween(Session session, String fromType, Object fromValue, String toType, Object minValue, Object maxValue) + { + String query = format( + "SELECT (CAST(v AS %s) BETWEEN CAST(%s AS %s) AND CAST(%s AS %s)) " + + "IS NOT DISTINCT FROM " + + "(CAST(%s AS %s) BETWEEN CAST(%s AS %s) AND CAST(%s AS %s)) " + + "FROM (VALUES CAST(%s AS %s)) t(v)", + toType, minValue, toType, maxValue, toType, + fromValue, toType, minValue, toType, maxValue, toType, + fromValue, fromType); + + boolean result = (boolean) assertions.execute(session, query) + .getMaterializedRows() + .get(0) + .getField(0); + + assertThat(result) + .as("Query evaluated to false: " + query) + .isTrue(); + } + @Test public void testUnwrapTimestampToDate() { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 670200247be5..1d5eaa3b19f6 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -2071,15 +2071,27 @@ else if (format == AVRO) { .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE CAST(d AS DATE) BETWEEN DATE '2015-05-15' AND DATE '2015-06-15'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE d >= TIMESTAMP '2015-05-15 12:00:00'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE d BETWEEN TIMESTAMP '2015-05-15 12:00:00' AND TIMESTAMP '2015-06-15 11:59:59.999999'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE d >= TIMESTAMP '2015-05-15 12:00:00.000001'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE d BETWEEN TIMESTAMP '2015-05-15 12:00:00.000001' AND TIMESTAMP '2015-06-15 11:59:59.999999'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE d BETWEEN TIMESTAMP '2015-05-15 12:00:00' AND TIMESTAMP '2015-06-15 12:00:00.00000'")) + .isNotFullyPushedDown(FilterNode.class); // date() assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE date(d) = DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE date(d) BETWEEN DATE '2015-05-15' AND DATE '2015-06-15'")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE date(d) BETWEEN TIMESTAMP '2015-05-15 12:00:00' AND TIMESTAMP '2015-06-15 12:00:00.00000'")) + .isFullyPushedDown(); // year() assertThat(query("SELECT * FROM test_hour_transform_timestamp WHERE year(d) = 2015")) @@ -2177,6 +2189,8 @@ else if (format == AVRO) { .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE CAST(d AS date) BETWEEN DATE '2015-05-15' AND DATE '2015-06-15'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE d >= TIMESTAMP '2015-05-15 12:00:00 UTC'")) .isFullyPushedDown(); @@ -2186,6 +2200,8 @@ else if (format == AVRO) { // date() assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE date(d) = DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE date(d) BETWEEN DATE '2015-05-15' AND DATE '2015-06-15'")) + .isFullyPushedDown(); // year() assertThat(query("SELECT * FROM test_hour_transform_timestamptz WHERE year(d) = 2015")) @@ -2357,6 +2373,8 @@ public void testDayTransformDate() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_day_transform_date WHERE CAST(d AS date) >= DATE '2015-01-13'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_date WHERE d BETWEEN DATE '2015-01-13' AND DATE '2015-01-14'")) + .isFullyPushedDown(); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_day_transform_date WHERE d >= TIMESTAMP '2015-01-13 00:00:00'")) @@ -2470,15 +2488,23 @@ else if (format == AVRO) { .isFullyPushedDown(); assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d BETWEEN DATE '2015-05-15' AND DATE '2015-05-16'")) + .isNotFullyPushedDown(FilterNode.class); assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d >= TIMESTAMP '2015-05-15 00:00:00'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d BETWEEN TIMESTAMP '2015-05-15 00:00:00' AND TIMESTAMP '2015-05-16 23:59:59.999999'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d >= TIMESTAMP '2015-05-15 00:00:00.000001'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE d BETWEEN TIMESTAMP '2015-05-15 00:00:00' AND TIMESTAMP '2015-05-16 00:00:00'")) + .isNotFullyPushedDown(FilterNode.class); // date() assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE date(d) = DATE '2015-05-15'")) .isFullyPushedDown(); + assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE date(d) BETWEEN DATE '2015-05-15' AND DATE '2015-06-15'")) + .isFullyPushedDown(); // year() assertThat(query("SELECT * FROM test_day_transform_timestamp WHERE year(d) = 2015")) @@ -2693,12 +2719,16 @@ public void testMonthTransformDate() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_month_transform_date WHERE CAST(d AS date) >= DATE '2020-06-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_date WHERE d BETWEEN DATE '2020-06-01' AND DATE '2020-07-31'")) + .isFullyPushedDown(); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_month_transform_date WHERE d >= TIMESTAMP '2015-06-01 00:00:00'")) .isFullyPushedDown(); assertThat(query("SELECT * FROM test_month_transform_date WHERE d >= TIMESTAMP '2015-05-01 00:00:00.000001'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_date WHERE d BETWEEN TIMESTAMP '2015-05-01 00:00:00' AND TIMESTAMP '2015-06-30 00:00:00'")) + .isFullyPushedDown(); // year() assertThat(query("SELECT * FROM test_month_transform_date WHERE year(d) = 2015")) @@ -2824,6 +2854,11 @@ else if (format == AVRO) { assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE d >= TIMESTAMP '2015-05-01 00:00:00.000001'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE d BETWEEN DATE '2015-05-01' AND DATE '2015-06-01'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE d BETWEEN DATE '2015-05-01' AND TIMESTAMP '2015-05-31 23:59:59.999999'")) + .isFullyPushedDown(); + // year() assertThat(query("SELECT * FROM test_month_transform_timestamp WHERE year(d) = 2015")) .isFullyPushedDown(); @@ -3028,6 +3063,10 @@ public void testYearTransformDate() .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_date WHERE CAST(d AS date) >= DATE '2015-01-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_date WHERE d BETWEEN DATE '2015-01-01' AND DATE '2016-01-01'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_date WHERE d BETWEEN DATE '2015-01-01' AND TIMESTAMP '2015-12-31 23:59:59.999999'")) + .isFullyPushedDown(); // d comparison with TIMESTAMP can be unwrapped assertThat(query("SELECT * FROM test_year_transform_date WHERE d >= TIMESTAMP '2015-01-01 00:00:00'")) @@ -3148,6 +3187,8 @@ else if (format == AVRO) { .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE CAST(d AS date) >= DATE '2015-01-02'")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE CAST(d AS date) BETWEEN DATE '2015-01-01' AND DATE '2016-12-31'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamp WHERE d >= TIMESTAMP '2015-01-01 00:00:00'")) .isFullyPushedDown(); @@ -3243,6 +3284,10 @@ else if (format == AVRO) { .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d >= with_timezone(DATE '2015-01-02', 'UTC')")) .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d BETWEEN DATE '2015-01-01' AND DATE '2016-01-01'")) + .isNotFullyPushedDown(FilterNode.class); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d BETWEEN with_timezone(DATE '2015-01-01', 'UTC') AND with_timezone(TIMESTAMP '2016-12-31 23:59:59.999999', 'UTC')")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE CAST(d AS date) >= DATE '2015-01-01'")) .isFullyPushedDown(); @@ -3252,6 +3297,8 @@ else if (format == AVRO) { // Engine can eliminate the table scan after connector accepts the filter pushdown .hasPlan(node(OutputNode.class, node(ValuesNode.class))) .returnsEmptyResult(); + assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE CAST(d AS date) BETWEEN DATE '2015-01-01' AND DATE '2016-12-31'")) + .isFullyPushedDown(); assertThat(query("SELECT * FROM test_year_transform_timestamptz WHERE d >= TIMESTAMP '2015-01-01 00:00:00 UTC'")) .isFullyPushedDown(); From ca41012bdb42672e85f7c5f83470581b51cf2442 Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Sat, 28 Dec 2024 00:54:18 +0100 Subject: [PATCH 3/4] empty From 1408d24d9db6566d09b9beaf913b0b2d52ff4474 Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Sun, 29 Dec 2024 08:35:53 +0100 Subject: [PATCH 4/4] empty