-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Unwrap casts in BETWEEN predicate
#14452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| import io.trino.spi.type.CharType; | ||
| import io.trino.spi.type.DecimalType; | ||
| import io.trino.spi.type.DoubleType; | ||
| import io.trino.spi.type.LongTimestamp; | ||
| import io.trino.spi.type.LongTimestampWithTimeZone; | ||
| import io.trino.spi.type.RealType; | ||
| import io.trino.spi.type.TimeWithTimeZoneType; | ||
|
|
@@ -39,6 +40,7 @@ | |
| import io.trino.sql.planner.NoOpSymbolResolver; | ||
| import io.trino.sql.planner.TypeAnalyzer; | ||
| import io.trino.sql.planner.TypeProvider; | ||
| import io.trino.sql.tree.BetweenPredicate; | ||
| import io.trino.sql.tree.Cast; | ||
| import io.trino.sql.tree.ComparisonExpression; | ||
| import io.trino.sql.tree.Expression; | ||
|
|
@@ -69,6 +71,7 @@ | |
| import static io.trino.spi.type.DoubleType.DOUBLE; | ||
| import static io.trino.spi.type.IntegerType.INTEGER; | ||
| import static io.trino.spi.type.RealType.REAL; | ||
| import static io.trino.spi.type.TimestampType.createTimestampType; | ||
| import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; | ||
| import static io.trino.spi.type.TypeUtils.isFloatingPointNaN; | ||
| import static io.trino.sql.ExpressionUtils.and; | ||
|
|
@@ -81,6 +84,8 @@ | |
| import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; | ||
| import static io.trino.type.DateTimes.PICOSECONDS_PER_MICROSECOND; | ||
| import static io.trino.type.DateTimes.scaleFactor; | ||
| import static java.lang.Float.intBitsToFloat; | ||
| import static java.lang.Math.toIntExact; | ||
| import static java.util.Objects.requireNonNull; | ||
|
|
@@ -171,6 +176,13 @@ public Expression rewriteComparisonExpression(ComparisonExpression node, Void co | |
| return unwrapCast(expression); | ||
| } | ||
|
|
||
| @Override | ||
| public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter) | ||
| { | ||
| BetweenPredicate expression = (BetweenPredicate) treeRewriter.defaultRewrite((Expression) node, null); | ||
| return unwrapCast(expression); | ||
| } | ||
|
|
||
| private Expression unwrapCast(ComparisonExpression expression) | ||
| { | ||
| // Canonicalization is handled by CanonicalizeExpressionRewriter | ||
|
|
@@ -387,6 +399,208 @@ private Optional<Expression> unwrapTimestampToDateCast(Session session, Timestam | |
| }; | ||
| } | ||
|
|
||
| private Expression unwrapCast(BetweenPredicate expression) | ||
| { | ||
| // Canonicalization is handled by CanonicalizeExpressionRewriter | ||
| if (!(expression.getValue() instanceof Cast cast)) { | ||
| return expression; | ||
| } | ||
|
|
||
| Object rangeMin = new ExpressionInterpreter(expression.getMin(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMin())) | ||
| .optimize(NoOpSymbolResolver.INSTANCE); | ||
| Object rangeMax = new ExpressionInterpreter(expression.getMax(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMax())) | ||
| .optimize(NoOpSymbolResolver.INSTANCE); | ||
|
|
||
| if (rangeMin == null || rangeMin instanceof NullLiteral) { | ||
| return new Cast(new NullLiteral(), toSqlType(BOOLEAN)); | ||
| } | ||
|
|
||
| if (rangeMin instanceof Expression || rangeMax instanceof Expression) { | ||
| return expression; | ||
| } | ||
|
|
||
| Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression()); | ||
| Type rangeMinType = typeAnalyzer.getType(session, types, expression.getMin()); | ||
| Type rangeMaxType = typeAnalyzer.getType(session, types, expression.getMax()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks like duplicating the logic quite much. Did you try something like @Override
public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
BetweenPredicate expression = (BetweenPredicate) treeRewriter.defaultRewrite((Expression) node, null);
ComparisonExpression lowBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
ComparisonExpression highBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
Expression lowBoundUnwrapped = unwrapCast(lowBound);
Expression highBoundUnwrapped = unwrapCast(highBound);
if (lowBound.equals(lowBoundUnwrapped) && highBound.equals(highBoundUnwrapped)) {
return expression;
}
...
}?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created #14648 with inspiration from #12797 I stumbled while trying to add the methods:
to As an alternative we could use custom operators included in
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is now covered by #12797. |
||
| if (rangeMinType != rangeMaxType) { | ||
| return expression; | ||
| } | ||
| Type targetType = rangeMinType; | ||
|
|
||
| if (sourceType instanceof TimestampType && targetType == DATE) { | ||
| return unwrapTimestampToDateCastForRange(session, (TimestampType) sourceType, cast.getExpression(), (long) rangeMin, (long) rangeMax); | ||
| } | ||
|
|
||
| if (targetType instanceof TimestampWithTimeZoneType) { | ||
| // Note: two TIMESTAMP WITH TIME ZONE values differing in zone only (same instant) are considered equal. | ||
| rangeMin = withTimeZone(((TimestampWithTimeZoneType) targetType), rangeMin, session.getTimeZoneKey()); | ||
| rangeMax = withTimeZone(((TimestampWithTimeZoneType) targetType), rangeMax, session.getTimeZoneKey()); | ||
| } | ||
|
|
||
| if (!hasInjectiveImplicitCoercion(sourceType, targetType, rangeMin)) { | ||
| return expression; | ||
| } | ||
|
|
||
| // Handle comparison against NaN. | ||
| // It must be done before source type range bounds are compared to target value. | ||
| if (isFloatingPointNaN(targetType, rangeMin) || isFloatingPointNaN(targetType, rangeMax)) { | ||
| return falseIfNotNull(cast.getExpression()); | ||
| } | ||
|
|
||
| if (compare(targetType, rangeMin, rangeMax) > 0) { | ||
| // range min gte range max | ||
| return falseIfNotNull(cast.getExpression()); | ||
| } | ||
|
|
||
| ResolvedFunction sourceToTarget = plannerContext.getMetadata().getCoercion(session, sourceType, targetType); | ||
|
|
||
| Optional<Type.Range> sourceRange = sourceType.getRange(); | ||
| if (sourceRange.isPresent()) { | ||
| Object maxInSourceType = sourceRange.get().getMax(); | ||
| Object maxInTargetType = coerce(maxInSourceType, sourceToTarget); | ||
|
|
||
| // NaN values of `rangeMin` and `rangeMax` are excluded at this point. Otherwise, NaN would be recognized as | ||
| // greater than source type upper bound, and incorrect expression might be derived. | ||
| int upperBoundRangeMinComparison = compare(targetType, rangeMin, maxInTargetType); | ||
| if (upperBoundRangeMinComparison > 0) { | ||
| // range min is larger than maximum representable value | ||
| return falseIfNotNull(cast.getExpression()); | ||
| } | ||
| if (upperBoundRangeMinComparison == 0) { | ||
| // range min equal to max representable value | ||
| return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, maxInSourceType, sourceType)); | ||
| } | ||
| int upperBoundRangeMaxComparison = compare(targetType, rangeMax, maxInTargetType); | ||
| if (upperBoundRangeMaxComparison >= 0) { | ||
| // range max larger or equal to the maximum representable value | ||
| return unwrapCast(new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast, expression.getMin())); | ||
| } | ||
|
|
||
| Object minInSourceType = sourceRange.get().getMin(); | ||
| Object minInTargetType = coerce(minInSourceType, sourceToTarget); | ||
|
|
||
| int lowerBoundRageMaxComparison = compare(targetType, rangeMax, minInTargetType); | ||
| if (lowerBoundRageMaxComparison < 0) { | ||
| // range max smaller than minimum representable value | ||
| return falseIfNotNull(cast.getExpression()); | ||
| } | ||
| if (lowerBoundRageMaxComparison == 0) { | ||
| // range max equal to min representable value | ||
| return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, minInSourceType, sourceType)); | ||
| } | ||
| int lowerBoundRageMinComparison = compare(targetType, rangeMin, minInTargetType); | ||
| if (lowerBoundRageMinComparison <= 0) { | ||
| // range min smaller or equal to the minimum representable value | ||
| return unwrapCast(new ComparisonExpression(LESS_THAN_OR_EQUAL, cast, expression.getMax())); | ||
| } | ||
| } | ||
|
|
||
| ResolvedFunction targetToSource; | ||
| try { | ||
| targetToSource = plannerContext.getMetadata().getCoercion(session, targetType, sourceType); | ||
| } | ||
| catch (OperatorNotFoundException e) { | ||
| // Without a cast between target -> source, there's nothing more we can do | ||
| return expression; | ||
| } | ||
|
|
||
| Object rangeMinLiteralInSourceType; | ||
| Object rangeMaxLiteralInSourceType; | ||
| try { | ||
| rangeMinLiteralInSourceType = coerce(rangeMin, targetToSource); | ||
| rangeMaxLiteralInSourceType = coerce(rangeMax, targetToSource); | ||
| } | ||
| catch (TrinoException e) { | ||
| // A failure to cast from target -> source type could be because: | ||
| // 1. missing cast | ||
| // 2. bad implementation | ||
| // 3. out of range or otherwise unrepresentable value | ||
| // Since we can't distinguish between those cases, take the conservative option | ||
| // and bail out. | ||
| return expression; | ||
| } | ||
|
|
||
| Object rangeMinRoundtripLiteral = coerce(rangeMinLiteralInSourceType, sourceToTarget); | ||
| int rangeMinLiteralVsRoundtripped = compare(targetType, rangeMin, rangeMinRoundtripLiteral); | ||
| Object rangeMaxRoundtripLiteral = coerce(rangeMaxLiteralInSourceType, sourceToTarget); | ||
| int rangeMaxLiteralVsRoundtripped = compare(targetType, rangeMax, rangeMaxRoundtripLiteral); | ||
|
|
||
| Expression rangeMinLiteralExpression = literalEncoder.toExpression(session, rangeMinLiteralInSourceType, sourceType); | ||
| Expression rangeMaxLiteralExpression = literalEncoder.toExpression(session, rangeMaxLiteralInSourceType, sourceType); | ||
| if (rangeMinLiteralVsRoundtripped > 0) { | ||
| // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source | ||
| // cannot produce a value larger than the next value in the source type | ||
| ComparisonExpression rangeMinComparisonExpression = new ComparisonExpression(GREATER_THAN, cast.getExpression(), rangeMinLiteralExpression); | ||
| if (rangeMaxLiteralVsRoundtripped >= 0) { | ||
| return and( | ||
| rangeMinComparisonExpression, | ||
| new ComparisonExpression(LESS_THAN_OR_EQUAL, cast.getExpression(), rangeMaxLiteralExpression)); | ||
| } | ||
| else { | ||
| return and( | ||
| rangeMinComparisonExpression, | ||
| // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot | ||
| // produce a value smaller than the next value in the source type | ||
| new ComparisonExpression(LESS_THAN, cast.getExpression(), rangeMaxLiteralExpression)); | ||
| } | ||
| } | ||
| else { | ||
| ComparisonExpression rangeMinComparisonExpression = new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast.getExpression(), rangeMinLiteralExpression); | ||
| if (rangeMaxLiteralVsRoundtripped >= 0) { | ||
| return new BetweenPredicate(cast.getExpression(), rangeMinLiteralExpression, rangeMaxLiteralExpression); | ||
| } | ||
| else { | ||
| return and( | ||
| rangeMinComparisonExpression, | ||
| // We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot | ||
| // produce a value smaller than the next value in the source type | ||
| new ComparisonExpression(LESS_THAN, cast.getExpression(), rangeMaxLiteralExpression)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private Expression unwrapTimestampToDateCastForRange( | ||
| Session session, | ||
| TimestampType sourceType, | ||
| Expression timestampExpression, | ||
| long minDateInclusive, | ||
| long maxDateInclusive) | ||
| { | ||
| ResolvedFunction targetToSource; | ||
| try { | ||
| targetToSource = plannerContext.getMetadata().getCoercion(session, DATE, sourceType); | ||
| } | ||
| catch (OperatorNotFoundException e) { | ||
| throw new TrinoException(GENERIC_INTERNAL_ERROR, e); | ||
| } | ||
|
|
||
| Expression minDateInclusiveTimestamp = literalEncoder.toExpression(session, coerce(minDateInclusive, targetToSource), sourceType); | ||
| Expression maxDateInclusiveTimestamp; | ||
| if (sourceType.isShort()) { | ||
| long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSource); | ||
| long maxTimestampInclusiveMicros = maxDateExclusive - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_SHORT_PRECISION); | ||
| maxDateInclusiveTimestamp = literalEncoder.toExpression(session, maxTimestampInclusiveMicros, sourceType); | ||
| } | ||
| else { | ||
| ResolvedFunction targetToSourceShortTimestamp; | ||
| try { | ||
| targetToSourceShortTimestamp = plannerContext.getMetadata().getCoercion(session, DATE, createTimestampType(TimestampType.MAX_SHORT_PRECISION)); | ||
| } | ||
| catch (OperatorNotFoundException e) { | ||
| throw new TrinoException(GENERIC_INTERNAL_ERROR, e); | ||
| } | ||
| long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSourceShortTimestamp); | ||
| long maxTimestampInclusiveMicros = maxDateExclusive - 1; | ||
| int picosOfMicro = toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_PRECISION)); | ||
| maxDateInclusiveTimestamp = literalEncoder.toExpression( | ||
| session, | ||
| new LongTimestamp(maxTimestampInclusiveMicros, picosOfMicro), | ||
| sourceType); | ||
| } | ||
|
|
||
| return new BetweenPredicate(timestampExpression, minDateInclusiveTimestamp, maxDateInclusiveTimestamp); | ||
| } | ||
|
|
||
| private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object value) | ||
| { | ||
| if ((source.equals(BIGINT) && target.equals(DOUBLE)) || | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.