Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.RealType;
import io.trino.spi.type.TimeWithTimeZoneType;
Expand All @@ -39,6 +40,7 @@
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
Expand Down Expand Up @@ -69,6 +71,7 @@
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.TimestampType.createTimestampType;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
import static io.trino.sql.ExpressionUtils.and;
Expand All @@ -81,6 +84,8 @@
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static io.trino.type.DateTimes.PICOSECONDS_PER_MICROSECOND;
import static io.trino.type.DateTimes.scaleFactor;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -171,6 +176,13 @@ public Expression rewriteComparisonExpression(ComparisonExpression node, Void co
return unwrapCast(expression);
}

@Override
public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
Comment thread
findinpath marked this conversation as resolved.
Outdated
{
BetweenPredicate expression = (BetweenPredicate) treeRewriter.defaultRewrite((Expression) node, null);
return unwrapCast(expression);
}

private Expression unwrapCast(ComparisonExpression expression)
{
// Canonicalization is handled by CanonicalizeExpressionRewriter
Expand Down Expand Up @@ -387,6 +399,208 @@ private Optional<Expression> unwrapTimestampToDateCast(Session session, Timestam
};
}

private Expression unwrapCast(BetweenPredicate expression)
{
// Canonicalization is handled by CanonicalizeExpressionRewriter
if (!(expression.getValue() instanceof Cast cast)) {
return expression;
}

Object rangeMin = new ExpressionInterpreter(expression.getMin(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMin()))
.optimize(NoOpSymbolResolver.INSTANCE);
Object rangeMax = new ExpressionInterpreter(expression.getMax(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMax()))
.optimize(NoOpSymbolResolver.INSTANCE);

if (rangeMin == null || rangeMin instanceof NullLiteral) {
return new Cast(new NullLiteral(), toSqlType(BOOLEAN));
}

if (rangeMin instanceof Expression || rangeMax instanceof Expression) {
return expression;
}

Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression());
Type rangeMinType = typeAnalyzer.getType(session, types, expression.getMin());
Type rangeMaxType = typeAnalyzer.getType(session, types, expression.getMax());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like duplicating the logic quite much.

Did you try something like

@Override
public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    BetweenPredicate expression = (BetweenPredicate) treeRewriter.defaultRewrite((Expression) node, null);

    ComparisonExpression lowBound = new ComparisonExpression(GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
    ComparisonExpression highBound = new ComparisonExpression(LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());

    Expression lowBoundUnwrapped = unwrapCast(lowBound);
    Expression highBoundUnwrapped = unwrapCast(highBound);
    if (lowBound.equals(lowBoundUnwrapped) && highBound.equals(highBoundUnwrapped)) {
        return expression;
    }
    ...
}

?
Why it cannot be made to work?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #14648 with inspiration from #12797

I stumbled while trying to add the methods:

  • getPreviousValue(Object)
  • getNextValue(Object)

to DateType.
I am assuming that checking the validity of a date needs joda-time expertise which is not available in trino-spi module.

As an alternative we could use custom operators included in DateOperators to cover this need.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I stumbled while trying to add the methods:

  • getPreviousValue(Object)
  • getNextValue(Object)

to DateType.

This is now covered by #12797.

if (rangeMinType != rangeMaxType) {
return expression;
}
Type targetType = rangeMinType;

if (sourceType instanceof TimestampType && targetType == DATE) {
return unwrapTimestampToDateCastForRange(session, (TimestampType) sourceType, cast.getExpression(), (long) rangeMin, (long) rangeMax);
}

if (targetType instanceof TimestampWithTimeZoneType) {
// Note: two TIMESTAMP WITH TIME ZONE values differing in zone only (same instant) are considered equal.
rangeMin = withTimeZone(((TimestampWithTimeZoneType) targetType), rangeMin, session.getTimeZoneKey());
rangeMax = withTimeZone(((TimestampWithTimeZoneType) targetType), rangeMax, session.getTimeZoneKey());
}

if (!hasInjectiveImplicitCoercion(sourceType, targetType, rangeMin)) {
return expression;
}

// Handle comparison against NaN.
// It must be done before source type range bounds are compared to target value.
if (isFloatingPointNaN(targetType, rangeMin) || isFloatingPointNaN(targetType, rangeMax)) {
return falseIfNotNull(cast.getExpression());
}

if (compare(targetType, rangeMin, rangeMax) > 0) {
// range min gte range max
return falseIfNotNull(cast.getExpression());
}

ResolvedFunction sourceToTarget = plannerContext.getMetadata().getCoercion(session, sourceType, targetType);

Optional<Type.Range> sourceRange = sourceType.getRange();
if (sourceRange.isPresent()) {
Object maxInSourceType = sourceRange.get().getMax();
Object maxInTargetType = coerce(maxInSourceType, sourceToTarget);

// NaN values of `rangeMin` and `rangeMax` are excluded at this point. Otherwise, NaN would be recognized as
// greater than source type upper bound, and incorrect expression might be derived.
int upperBoundRangeMinComparison = compare(targetType, rangeMin, maxInTargetType);
if (upperBoundRangeMinComparison > 0) {
// range min is larger than maximum representable value
return falseIfNotNull(cast.getExpression());
}
if (upperBoundRangeMinComparison == 0) {
// range min equal to max representable value
return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, maxInSourceType, sourceType));
}
int upperBoundRangeMaxComparison = compare(targetType, rangeMax, maxInTargetType);
if (upperBoundRangeMaxComparison >= 0) {
// range max larger or equal to the maximum representable value
return unwrapCast(new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast, expression.getMin()));
}

Object minInSourceType = sourceRange.get().getMin();
Object minInTargetType = coerce(minInSourceType, sourceToTarget);

int lowerBoundRageMaxComparison = compare(targetType, rangeMax, minInTargetType);
if (lowerBoundRageMaxComparison < 0) {
// range max smaller than minimum representable value
return falseIfNotNull(cast.getExpression());
}
if (lowerBoundRageMaxComparison == 0) {
// range max equal to min representable value
return new ComparisonExpression(EQUAL, cast.getExpression(), literalEncoder.toExpression(session, minInSourceType, sourceType));
}
int lowerBoundRageMinComparison = compare(targetType, rangeMin, minInTargetType);
if (lowerBoundRageMinComparison <= 0) {
// range min smaller or equal to the minimum representable value
return unwrapCast(new ComparisonExpression(LESS_THAN_OR_EQUAL, cast, expression.getMax()));
}
}

ResolvedFunction targetToSource;
try {
targetToSource = plannerContext.getMetadata().getCoercion(session, targetType, sourceType);
}
catch (OperatorNotFoundException e) {
// Without a cast between target -> source, there's nothing more we can do
return expression;
}

Object rangeMinLiteralInSourceType;
Object rangeMaxLiteralInSourceType;
try {
rangeMinLiteralInSourceType = coerce(rangeMin, targetToSource);
rangeMaxLiteralInSourceType = coerce(rangeMax, targetToSource);
}
catch (TrinoException e) {
// A failure to cast from target -> source type could be because:
// 1. missing cast
// 2. bad implementation
// 3. out of range or otherwise unrepresentable value
// Since we can't distinguish between those cases, take the conservative option
// and bail out.
return expression;
}

Object rangeMinRoundtripLiteral = coerce(rangeMinLiteralInSourceType, sourceToTarget);
int rangeMinLiteralVsRoundtripped = compare(targetType, rangeMin, rangeMinRoundtripLiteral);
Object rangeMaxRoundtripLiteral = coerce(rangeMaxLiteralInSourceType, sourceToTarget);
int rangeMaxLiteralVsRoundtripped = compare(targetType, rangeMax, rangeMaxRoundtripLiteral);

Expression rangeMinLiteralExpression = literalEncoder.toExpression(session, rangeMinLiteralInSourceType, sourceType);
Expression rangeMaxLiteralExpression = literalEncoder.toExpression(session, rangeMaxLiteralInSourceType, sourceType);
if (rangeMinLiteralVsRoundtripped > 0) {
// We expect implicit coercions to be order-preserving, so the result of converting back from target -> source
// cannot produce a value larger than the next value in the source type
ComparisonExpression rangeMinComparisonExpression = new ComparisonExpression(GREATER_THAN, cast.getExpression(), rangeMinLiteralExpression);
if (rangeMaxLiteralVsRoundtripped >= 0) {
return and(
rangeMinComparisonExpression,
new ComparisonExpression(LESS_THAN_OR_EQUAL, cast.getExpression(), rangeMaxLiteralExpression));
}
else {
return and(
rangeMinComparisonExpression,
// We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot
// produce a value smaller than the next value in the source type
new ComparisonExpression(LESS_THAN, cast.getExpression(), rangeMaxLiteralExpression));
}
}
else {
ComparisonExpression rangeMinComparisonExpression = new ComparisonExpression(GREATER_THAN_OR_EQUAL, cast.getExpression(), rangeMinLiteralExpression);
if (rangeMaxLiteralVsRoundtripped >= 0) {
return new BetweenPredicate(cast.getExpression(), rangeMinLiteralExpression, rangeMaxLiteralExpression);
}
else {
return and(
rangeMinComparisonExpression,
// We expect implicit coercions to be order-preserving, so the result of converting back from target -> source cannot
// produce a value smaller than the next value in the source type
new ComparisonExpression(LESS_THAN, cast.getExpression(), rangeMaxLiteralExpression));
}
}
}

private Expression unwrapTimestampToDateCastForRange(
Session session,
TimestampType sourceType,
Expression timestampExpression,
long minDateInclusive,
long maxDateInclusive)
{
ResolvedFunction targetToSource;
try {
targetToSource = plannerContext.getMetadata().getCoercion(session, DATE, sourceType);
}
catch (OperatorNotFoundException e) {
throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
}

Expression minDateInclusiveTimestamp = literalEncoder.toExpression(session, coerce(minDateInclusive, targetToSource), sourceType);
Expression maxDateInclusiveTimestamp;
if (sourceType.isShort()) {
long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSource);
long maxTimestampInclusiveMicros = maxDateExclusive - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_SHORT_PRECISION);
maxDateInclusiveTimestamp = literalEncoder.toExpression(session, maxTimestampInclusiveMicros, sourceType);
}
else {
ResolvedFunction targetToSourceShortTimestamp;
try {
targetToSourceShortTimestamp = plannerContext.getMetadata().getCoercion(session, DATE, createTimestampType(TimestampType.MAX_SHORT_PRECISION));
}
catch (OperatorNotFoundException e) {
throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
}
long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSourceShortTimestamp);
long maxTimestampInclusiveMicros = maxDateExclusive - 1;
int picosOfMicro = toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_PRECISION));
maxDateInclusiveTimestamp = literalEncoder.toExpression(
session,
new LongTimestamp(maxTimestampInclusiveMicros, picosOfMicro),
sourceType);
}

return new BetweenPredicate(timestampExpression, minDateInclusiveTimestamp, maxDateInclusiveTimestamp);
}

private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object value)
{
if ((source.equals(BIGINT) && target.equals(DOUBLE)) ||
Expand Down
Loading