Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import io.trino.spi.TrinoException;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateType;
Comment thread
findinpath marked this conversation as resolved.
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.RealType;
import io.trino.spi.type.TimeWithTimeZoneType;
Expand All @@ -39,6 +41,7 @@
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
Expand Down Expand Up @@ -68,7 +71,9 @@
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochMillisAndFraction;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.TimestampType.createTimestampType;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND;
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
import static io.trino.sql.ExpressionUtils.and;
Expand All @@ -81,6 +86,8 @@
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
import static io.trino.type.DateTimes.PICOSECONDS_PER_MICROSECOND;
import static io.trino.type.DateTimes.scaleFactor;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -171,6 +178,51 @@ public Expression rewriteComparisonExpression(ComparisonExpression node, Void co
return unwrapCast(expression);
}

@Override
public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
Comment thread
findinpath marked this conversation as resolved.
{
BetweenPredicate expression = (BetweenPredicate) treeRewriter.defaultRewrite((Expression) node, null);
return unwrapCast(expression);
}

private Expression unwrapCast(BetweenPredicate expression)
{
// Canonicalization is handled by CanonicalizeExpressionRewriter
if (!(expression.getValue() instanceof Cast cast)) {
return expression;
}

Object min = new ExpressionInterpreter(expression.getMin(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMin()))
.optimize(NoOpSymbolResolver.INSTANCE);
Object max = new ExpressionInterpreter(expression.getMax(), plannerContext, session, typeAnalyzer.getTypes(session, types, expression.getMax()))
.optimize(NoOpSymbolResolver.INSTANCE);

if (min == null || min instanceof NullLiteral || max == null || max instanceof NullLiteral) {
return new Cast(new NullLiteral(), toSqlType(BOOLEAN));
}

if (min instanceof Expression || max instanceof Expression) {
return expression;
}

Type sourceType = typeAnalyzer.getType(session, types, cast.getExpression());
Type minType = typeAnalyzer.getType(session, types, expression.getMin());
Type maxType = typeAnalyzer.getType(session, types, expression.getMax());
verify(minType.equals(maxType), "Mismatched types: %s and %s", minType, maxType);

if (sourceType instanceof TimestampType && minType == DATE) {
Comment thread
findinpath marked this conversation as resolved.
return unwrapTimestampToDateCastForRange(session, (TimestampType) sourceType, cast.getExpression(), (long) min, (long) max);
}
if (!hasInjectiveImplicitCoercion(sourceType, minType, min) || !hasInjectiveImplicitCoercion(sourceType, maxType, max)) {
return expression;
}
if (sourceType instanceof DateType && minType instanceof TimestampType) {
return unwrapDateToTimestampCastForRange(expression, cast, min, max, sourceType, minType);
}

return expression;
}

private Expression unwrapCast(ComparisonExpression expression)
{
// Canonicalization is handled by CanonicalizeExpressionRewriter
Expand Down Expand Up @@ -387,6 +439,81 @@ private Optional<Expression> unwrapTimestampToDateCast(Session session, Timestam
};
}

/**
 * Rewrites {@code CAST(ts AS date) BETWEEN minDate AND maxDate} into a BETWEEN predicate
 * directly on {@code ts}. The lower bound is {@code minDate} coerced to the timestamp type;
 * the upper bound is the last value representable in {@code sourceType} before the coerced
 * value of {@code maxDate + 1}.
 *
 * @param session session used to look up the coercion and to encode the literals
 * @param sourceType timestamp type of the expression under the cast
 * @param timestampExpression the expression under the cast
 * @param minDateInclusive lower bound, as the runtime value of the DATE literal
 * @param maxDateInclusive upper bound, as the runtime value of the DATE literal
 * @return a BETWEEN predicate over {@code timestampExpression}
 * @throws TrinoException with GENERIC_INTERNAL_ERROR when no DATE to timestamp coercion is found
 */
private Expression unwrapTimestampToDateCastForRange(
        Session session,
        TimestampType sourceType,
        Expression timestampExpression,
        long minDateInclusive,
        long maxDateInclusive)
{
    ResolvedFunction targetToSource;
    try {
        targetToSource = plannerContext.getMetadata().getCoercion(session, DATE, sourceType);
    }
    catch (OperatorNotFoundException e) {
        throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
    }

    Expression minDateInclusiveTimestamp = literalEncoder.toExpression(session, coerce(minDateInclusive, targetToSource), sourceType);
    Expression maxDateInclusiveTimestamp;
    if (sourceType.isShort()) {
        long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSource);
        // Step back by one unit of the source precision, expressed in microseconds
        long maxTimestampInclusiveMicros = maxDateExclusive - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_SHORT_PRECISION);
        maxDateInclusiveTimestamp = literalEncoder.toExpression(session, maxTimestampInclusiveMicros, sourceType);
    }
    else {
        ResolvedFunction targetToSourceShortTimestamp;
        try {
            targetToSourceShortTimestamp = plannerContext.getMetadata().getCoercion(session, DATE, createTimestampType(TimestampType.MAX_SHORT_PRECISION));
        }
        catch (OperatorNotFoundException e) {
            throw new TrinoException(GENERIC_INTERNAL_ERROR, e);
        }
        long maxDateExclusive = (long) coerce(maxDateInclusive + 1, targetToSourceShortTimestamp);
        // The source precision is finer than a microsecond here, so the whole-microsecond part of the
        // last representable value is exactly one microsecond before the exclusive bound. The remainder
        // (one microsecond minus one unit of the source precision) is carried in the picosecond fraction.
        long maxTimestampInclusiveMicros = maxDateExclusive - 1;
        int picosOfMicro = toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(sourceType.getPrecision(), TimestampType.MAX_PRECISION));
        maxDateInclusiveTimestamp = literalEncoder.toExpression(
                session,
                new LongTimestamp(maxTimestampInclusiveMicros, picosOfMicro),
                sourceType);
    }

    return new BetweenPredicate(timestampExpression, minDateInclusiveTimestamp, maxDateInclusiveTimestamp);
}

/**
 * Rewrites {@code CAST(d AS timestamp) BETWEEN min AND max} into a BETWEEN predicate directly
 * on {@code d}, by coercing both bounds to the source type.
 * <p>
 * The rewrite is only applied when {@code min} survives a target -> source -> target round trip
 * unchanged. Otherwise the coerced lower bound would differ from the original one and the
 * rewritten range could accept values the original predicate rejects.
 *
 * @param expression the original predicate, returned unchanged whenever the rewrite does not apply
 * @param cast the cast found as the value of {@code expression}
 * @param min constant lower bound, as a runtime value of {@code minType}
 * @param max constant upper bound, as a runtime value of {@code minType}
 * @param sourceType type of the expression under the cast
 * @param minType type of both bounds, i.e. the target type of the cast
 * @return the unwrapped predicate, or {@code expression}
 */
private BetweenPredicate unwrapDateToTimestampCastForRange(BetweenPredicate expression, Cast cast, Object min, Object max, Type sourceType, Type minType)
{
    ResolvedFunction targetToSource;
    ResolvedFunction sourceToTarget;
    try {
        targetToSource = plannerContext.getMetadata().getCoercion(session, minType, sourceType);
        sourceToTarget = plannerContext.getMetadata().getCoercion(session, sourceType, minType);
    }
    catch (OperatorNotFoundException e) {
        // Without casts in both directions, there's nothing more we can do
        return expression;
    }

    Object minLiteralInSourceType;
    Object maxLiteralInSourceType;
    Object minRoundTrip;
    try {
        minLiteralInSourceType = coerce(min, targetToSource);
        maxLiteralInSourceType = coerce(max, targetToSource);
        minRoundTrip = coerce(minLiteralInSourceType, sourceToTarget);
    }
    catch (TrinoException e) {
        // A failure to cast from target -> source type could be because:
        // 1. missing cast
        // 2. bad implementation
        // 3. out of range or otherwise unrepresentable value
        // Since we can't distinguish between those cases, take the conservative option
        // and bail out.
        return expression;
    }

    if (!min.equals(minRoundTrip)) {
        // The lower bound is not exactly representable in the source type, so the coerced
        // value is a different bound than the one the query asked for. Keep the original.
        return expression;
    }
    // NOTE(review): the upper bound is used without a round-trip check, which is only sound
    // if the target -> source cast never rounds up. Confirm for the timestamp -> date cast.

    return new BetweenPredicate(
            cast.getExpression(),
            literalEncoder.toExpression(session, minLiteralInSourceType, sourceType),
            literalEncoder.toExpression(session, maxLiteralInSourceType, sourceType));
}

private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object value)
{
if ((source.equals(BIGINT) && target.equals(DOUBLE)) ||
Expand Down Expand Up @@ -506,7 +633,7 @@ private static Object withTimeZone(TimestampWithTimeZoneType type, Object value,
return packDateTimeWithZone(unpackMillisUtc((long) value), newZone);
}
LongTimestampWithTimeZone longTimestampWithTimeZone = (LongTimestampWithTimeZone) value;
return LongTimestampWithTimeZone.fromEpochMillisAndFraction(longTimestampWithTimeZone.getEpochMillis(), longTimestampWithTimeZone.getPicosOfMilli(), newZone);
return fromEpochMillisAndFraction(longTimestampWithTimeZone.getEpochMillis(), longTimestampWithTimeZone.getPicosOfMilli(), newZone);
}

private static TimeZoneKey getTimeZone(TimestampWithTimeZoneType type, Object value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,93 @@ public Expression rewriteComparisonExpression(ComparisonExpression node, Void co
return unwrapDateTrunc(expression);
}

@Override
public Expression rewriteBetweenPredicate(BetweenPredicate node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    // Rewrite the children first (bottom-up), then try to unwrap date_trunc from the resulting predicate.
    Expression rewritten = treeRewriter.defaultRewrite((Expression) node, null);
    BetweenPredicate between = (BetweenPredicate) rewritten;
    return unwrapDateTrunc(between);
}

/**
 * Given constant temporal unit U, the constant expressions:
 * <ul>
 * <li>tmin, which must already be aligned to U, i.e. date_trunc(U, tmin) = tmin</li>
 * <li>tmax</li>
 * </ul>
 * and epsilon as the minimum unit of precision for the temporal
 * type of the expression dt, rewrite expression of the form
 * <pre>date_trunc(U, dt) BETWEEN tmin AND tmax</pre>
 * <p>
 * into
 * <pre>dt BETWEEN tmin AND floor(tmax, U) + U - epsilon</pre>
 * <p>
 * The expression is returned unchanged when tmin is not aligned to U (using it as the
 * lower bound for dt would widen the range), when exactly one bound is NULL, when a
 * bound is not constant, or when the unit or type is not supported.
 *
 * @param expression the BETWEEN predicate, with its children already rewritten
 * @return an equivalent expression, or {@code expression} itself when no rewrite applies
 */
private Expression unwrapDateTrunc(BetweenPredicate expression)
{
    if (!(expression.getValue() instanceof FunctionCall call) ||
            !extractFunctionName(call.getName()).equals("date_trunc") ||
            call.getArguments().size() != 2) {
        return expression;
    }

    Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression);
    Expression unitExpression = call.getArguments().get(0);
    if (!(expressionTypes.get(NodeRef.of(unitExpression)) instanceof VarcharType) || !isEffectivelyLiteral(plannerContext, session, unitExpression)) {
        return expression;
    }
    Slice unitName = (Slice) new ExpressionInterpreter(unitExpression, plannerContext, session, expressionTypes)
            .optimize(NoOpSymbolResolver.INSTANCE);
    if (unitName == null) {
        return expression;
    }

    Expression argument = call.getArguments().get(1);
    Type argumentType = expressionTypes.get(NodeRef.of(argument));

    Type minType = expressionTypes.get(NodeRef.of(expression.getMin()));
    Type maxType = expressionTypes.get(NodeRef.of(expression.getMax()));
    verify(argumentType.equals(minType), "Mismatched types: %s and %s", argumentType, minType);
    verify(argumentType.equals(maxType), "Mismatched types: %s and %s", argumentType, maxType);

    Object min = new ExpressionInterpreter(expression.getMin(), plannerContext, session, expressionTypes)
            .optimize(NoOpSymbolResolver.INSTANCE);
    Object max = new ExpressionInterpreter(expression.getMax(), plannerContext, session, expressionTypes)
            .optimize(NoOpSymbolResolver.INSTANCE);

    boolean minIsNull = min == null || min instanceof NullLiteral;
    boolean maxIsNull = max == null || max instanceof NullLiteral;
    if (minIsNull && maxIsNull) {
        // value >= NULL AND value <= NULL is NULL whatever the value is
        return new Cast(new NullLiteral(), toSqlType(BOOLEAN));
    }
    if (minIsNull || maxIsNull) {
        // With a single NULL bound the predicate is "NULL AND <comparison>", which is FALSE
        // (not NULL) whenever the comparison is FALSE. It cannot be folded to a constant.
        return expression;
    }

    if (min instanceof Expression || max instanceof Expression) {
        return expression;
    }
    if (minType instanceof TimestampWithTimeZoneType || maxType instanceof TimestampWithTimeZoneType) {
        // Cannot replace with a range due to how date_trunc operates on value's local date/time.
        // I.e. unwrapping is possible only when values are all of some fixed zone and the zone is known.
        return expression;
    }

    Optional<SupportedUnit> unitIfSupported = Enums.getIfPresent(SupportedUnit.class, unitName.toStringUtf8().toUpperCase(Locale.ENGLISH)).toJavaUtil();
    if (unitIfSupported.isEmpty()) {
        return expression;
    }
    SupportedUnit unit = unitIfSupported.get();
    if (minType == DATE && (unit == SupportedUnit.DAY || unit == SupportedUnit.HOUR)) {
        // DAY case handled by CanonicalizeExpressionRewriter, other is illegal, will fail
        return expression;
    }

    ResolvedFunction resolvedFunction = plannerContext.getMetadata().decodeFunction(call.getName());

    Object rangeLow = functionInvoker.invoke(resolvedFunction, session.toConnectorSession(), ImmutableList.of(unitName, min));
    int compareMin = compare(minType, rangeLow, min);
    verify(compareMin <= 0, "Truncation of %s value %s resulted in a bigger value %s", minType, min, rangeLow);
    if (compareMin != 0) {
        // min lies strictly inside a unit. date_trunc(U, dt) >= min then only holds from the start of
        // the NEXT unit, so "dt >= min" would wrongly accept the rest of min's own unit. Bail out.
        return expression;
    }

    Object rangeHigh = functionInvoker.invoke(resolvedFunction, session.toConnectorSession(), ImmutableList.of(unitName, max));
    int compareMax = compare(maxType, rangeHigh, max);
    verify(compareMax <= 0, "Truncation of %s value %s resulted in a bigger value %s", maxType, max, rangeHigh);

    return new BetweenPredicate(
            argument,
            toExpression(min, minType),
            toExpression(calculateRangeEndInclusive(rangeHigh, maxType, unit), maxType));
}

// Simplify `date_trunc(unit, d) ? value`
private Expression unwrapDateTrunc(ComparisonExpression expression)
{
Expand Down Expand Up @@ -281,12 +368,12 @@ private Object calculateRangeEndInclusive(Object rangeStart, Type type, Supporte
};
long endExclusiveMicros = endExclusive.toEpochSecond(ZoneOffset.UTC) * MICROSECONDS_PER_SECOND
+ LongMath.divide(endExclusive.getNano(), NANOSECONDS_PER_MICROSECOND, UNNECESSARY);
return endExclusiveMicros - scaleFactor(timestampType.getPrecision(), 6);
return endExclusiveMicros - scaleFactor(timestampType.getPrecision(), TimestampType.MAX_SHORT_PRECISION);
}
LongTimestamp longTimestamp = (LongTimestamp) rangeStart;
verify(longTimestamp.getPicosOfMicro() == 0, "Unexpected picos in %s, value not rounded to %s", rangeStart, rangeUnit);
long endInclusiveMicros = (long) calculateRangeEndInclusive(longTimestamp.getEpochMicros(), createTimestampType(6), rangeUnit);
return new LongTimestamp(endInclusiveMicros, toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(timestampType.getPrecision(), 12)));
long endInclusiveMicros = (long) calculateRangeEndInclusive(longTimestamp.getEpochMicros(), createTimestampType(TimestampType.MAX_SHORT_PRECISION), rangeUnit);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name is "endInclusiveMicros".
The code used 6, and it's known that 10^(-6) s is a microsecond.

After the change, the code uses TimestampType.MAX_SHORT_PRECISION. It's not obvious that this is correct (is short precision actually microseconds?). Thus, this change actually decreases readability.

return new LongTimestamp(endInclusiveMicros, toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(timestampType.getPrecision(), TimestampType.MAX_PRECISION)));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here: the use of PICOSECONDS_PER_MICROSECOND implies that we know we're dealing with picoseconds, i.e. 10^(-12) s, so it matched the corresponding 12 on this line.

After the change, we invoke the "max precision" constant, but we still rely on it having an actual value of 12.

}
throw new UnsupportedOperationException("Unsupported type: " + type);
}
Expand Down
Loading