Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import javax.annotation.Nullable;

import java.lang.invoke.MethodHandle;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -81,19 +82,22 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Iterators.peekingIterator;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.airlift.slice.SliceUtf8.getCodePointAt;
import static io.airlift.slice.SliceUtf8.lengthOfCodePoint;
import static io.airlift.slice.SliceUtf8.setCodePointAt;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.function.OperatorType.SATURATED_FLOOR_CAST;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.TypeUtils.isFloatingPointNaN;
import static io.trino.sql.ExpressionUtils.and;
import static io.trino.sql.ExpressionUtils.combineConjuncts;
Expand All @@ -104,6 +108,7 @@
import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL;
Expand Down Expand Up @@ -502,6 +507,19 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node,
}
if (symbolExpression instanceof Cast) {
Cast castExpression = (Cast) symbolExpression;
// type of expression which is then cast to type of value
Type castSourceType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression.getExpression())), "No type for Cast source expression");
Type castTargetType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression)), "No type for Cast target expression");
if (castSourceType instanceof VarcharType && castTargetType == DATE && !castExpression.isSafe()) {
Optional<ExtractionResult> result = createVarcharCastToDateComparisonExtractionResult(
node,
(VarcharType) castSourceType,
normalized.getValue(),
complement);
if (result.isPresent()) {
return result.get();
}
}
if (!isImplicitCoercion(expressionTypes, castExpression)) {
//
// we cannot use non-coercion cast to literal_type on symbol side to build tuple domain
Expand All @@ -524,9 +542,6 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node,
return super.visitComparisonExpression(node, complement);
}

// type of expression which is then cast to type of value
Type castSourceType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression.getExpression())), "No type for Cast source expression");

// we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side
Optional<Expression> coercedExpression = coerceComparisonWithRounding(
castSourceType, castExpression.getExpression(), normalized.getValue(), normalized.getComparisonOperator());
Expand Down Expand Up @@ -588,6 +603,115 @@ private Map<NodeRef<Expression>, Type> analyzeExpression(Expression expression)
return typeAnalyzer.getTypes(session, types, expression);
}

private Optional<ExtractionResult> createVarcharCastToDateComparisonExtractionResult(
ComparisonExpression node,
VarcharType sourceType,
NullableValue value,
boolean complement)
{
Cast castExpression = (Cast) node.getLeft();
Expression sourceExpression = castExpression.getExpression();
ComparisonExpression.Operator comparisonOperator = node.getOperator();
requireNonNull(value, "value is null");

if (complement || value.isNull()) {
return Optional.empty();
}
if (!(sourceExpression instanceof SymbolReference)) {
// Calculation is not useful
return Optional.empty();
}
Symbol sourceSymbol = Symbol.from(sourceExpression);

if (!sourceType.isUnbounded() && sourceType.getBoundedLength() < 10) {
// too short
return Optional.empty();
}

LocalDate date = LocalDate.ofEpochDay(((long) value.getValue()));
if (date.getYear() < 1001 || date.getYear() > 9998) {
// Edge cases. 1-year margin so that we can go to next/prev year for < or > comparisons
return Optional.empty();
}

// superset of possible values, for the "normal case"
ValueSet valueSet;
boolean nullAllowed = false;

switch (comparisonOperator) {
case EQUAL:
valueSet = dateStringRanges(date, sourceType);
break;
case NOT_EQUAL:
case IS_DISTINCT_FROM:
if (date.getDayOfMonth() < 10) {
// TODO: possible to handle but cumbersome
return Optional.empty();
}
valueSet = ValueSet.all(sourceType).subtract(dateStringRanges(date, sourceType));
nullAllowed = (comparisonOperator == IS_DISTINCT_FROM);
break;
case LESS_THAN:
case LESS_THAN_OR_EQUAL:
valueSet = ValueSet.ofRanges(Range.lessThan(sourceType, utf8Slice(Integer.toString(date.getYear() + 1))));
break;
case GREATER_THAN:
case GREATER_THAN_OR_EQUAL:
valueSet = ValueSet.ofRanges(Range.greaterThan(sourceType, utf8Slice(Integer.toString(date.getYear() - 1))));
break;
default:
return Optional.empty();
}

// Date representations starting with whitespace, sign or leading zeroes.
valueSet = valueSet.union(ValueSet.ofRanges(
Range.lessThan(sourceType, utf8Slice("1")),
Comment thread
findepi marked this conversation as resolved.
Outdated
Range.greaterThan(sourceType, utf8Slice("9"))));

return Optional.of(new ExtractionResult(
TupleDomain.withColumnDomains(ImmutableMap.of(sourceSymbol, Domain.create(valueSet, nullAllowed))),
node));
}

/**
* @return Date representations of the form 2005-09-09, 2005-09-9, 2005-9-09 and 2005-9-9 expanded to ranges:
* {@code [2005-09-09, 2005-09-0:), [2005-09-9, 2005-09-:), [2005-9-09, 2005-9-0:), [2005-9-9, 2005-9-:)}
* (the {@code :} character is the next one after {@code 9}).
*/
private static SortedRangeSet dateStringRanges(LocalDate date, VarcharType domainType)
{
checkArgument(date.getYear() >= 1000 && date.getYear() <= 9999, "Unsupported date: %s", date);

int month = date.getMonthValue();
int day = date.getDayOfMonth();
boolean isMonthSingleDigit = date.getMonthValue() < 10;
boolean isDaySingleDigit = date.getDayOfMonth() < 10;

// A specific date value like 2005-09-10 can be a result of a CAST for number of various forms,
// as the value can have optional sign, leading zeros for the year, and surrounding whitespace,
// E.g. ' +002005-9-9 '.

List<Range> valueRanges = new ArrayList<>(4);
for (boolean useSingleDigitMonth : List.of(true, false)) {
for (boolean useSingleDigitDay : List.of(true, false)) {
if (useSingleDigitMonth && !isMonthSingleDigit) {
continue;
}
if (useSingleDigitDay && !isDaySingleDigit) {
continue;
}
String dateString = date.getYear() +
((!useSingleDigitMonth && isMonthSingleDigit) ? "-0" : "-") + month +
((!useSingleDigitDay && isDaySingleDigit) ? "-0" : "-") + day;
String nextStringPrefix = dateString.substring(0, dateString.length() - 1) + (char) (dateString.charAt(dateString.length() - 1) + 1); // cannot overflow
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity question. Char SQL type ordering uses the same order as Java char ordering? ie. if you add one to a char in java you also get the next char when it comes to SQL char sorting

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general -- no (i think it breaks for non-BMP)
Here, we're operating within ASCII (digits and hyphen). '9' + 1 produces ':', still ASCII.

verify(dateString.length() <= domainType.getLength().orElse(Integer.MAX_VALUE), "dateString length exceeds type bounds");
verify(dateString.length() == nextStringPrefix.length(), "Next string length mismatch");
valueRanges.add(Range.range(domainType, utf8Slice(dateString), true, utf8Slice(nextStringPrefix), false));
}
}
return (SortedRangeSet) ValueSet.ofRanges(valueRanges);
}

private static Optional<ExtractionResult> createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement)
{
if (value == null) {
Expand Down Expand Up @@ -990,7 +1114,7 @@ private Optional<ExtractionResult> tryVisitLikePredicate(LikePredicate node, Boo
VarcharType varcharType = (VarcharType) type;

Symbol symbol = Symbol.from(node.getValue());
Slice pattern = Slices.utf8Slice(((StringLiteral) node.getPattern()).getValue());
Slice pattern = utf8Slice(((StringLiteral) node.getPattern()).getValue());
Optional<Slice> escape = node.getEscape()
.map(StringLiteral.class::cast)
.map(StringLiteral::getValue)
Expand Down Expand Up @@ -1064,7 +1188,7 @@ private Optional<ExtractionResult> tryVisitStartsWithFunction(FunctionCall node,
}

Symbol symbol = Symbol.from(target);
Slice constantPrefix = Slices.utf8Slice(((StringLiteral) prefix).getValue());
Slice constantPrefix = utf8Slice(((StringLiteral) prefix).getValue());

return createRangeDomain(type, constantPrefix).map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.testing.assertions;

import java.util.function.Consumer;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;

/**
 * Assertion helpers for code that may either return a value or fail.
 */
public final class TestUtil
{
    private TestUtil() {}

    /**
     * Runs {@code callback} and routes its outcome to exactly one of the two verifiers.
     * <p>
     * Anything thrown by {@code callback} is passed to {@code verifyFailure}. A value returned by
     * {@code callback} is passed to {@code verifyResults}. Whatever {@code verifyResults} throws
     * propagates to the caller and is never passed to {@code verifyFailure}.
     *
     * @param callback the computation under test
     * @param verifyResults checks applied to the value returned by {@code callback}
     * @param verifyFailure checks applied to the throwable raised by {@code callback}
     * @throws NullPointerException if any argument is null
     */
    public static <T> void verifyResultOrFailure(Supplier<T> callback, Consumer<T> verifyResults, Consumer<Throwable> verifyFailure)
    {
        requireNonNull(callback, "callback is null");
        requireNonNull(verifyResults, "verifyResults is null");
        requireNonNull(verifyFailure, "verifyFailure is null");

        T outcome = null;
        Throwable failure = null;
        try {
            outcome = callback.get();
        }
        catch (Throwable thrown) {
            failure = thrown;
        }

        // Verifiers run outside the try block so that their own failures are not intercepted
        if (failure != null) {
            verifyFailure.accept(failure);
        }
        else {
            verifyResults.accept(outcome);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ public static Slice castToVarchar(@LiteralParameter("x") long x, @SqlType(Standa
@SqlType(StandardTypes.DATE)
public static long castFromVarchar(@SqlType("varchar(x)") Slice value)
{
// Note: update DomainTranslator.Visitor.createVarcharCastToDateComparisonExtractionResult whenever CAST behavior changes.

try {
return parseDate(trim(value).toStringUtf8());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ private DateTimeUtils() {}

public static int parseDate(String value)
{
// Note: update DomainTranslator.Visitor.createVarcharCastToDateComparisonExtractionResult whenever varchar->date conversion (CAST) behavior changes.

// in order to follow the standard, we should validate the value:
// - the required format is 'YYYY-MM-DD'
// - all components should be unsigned numbers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,90 @@ public void testFromComparisonsWithCoercions()
assertPredicateIsAlwaysFalse(not(isDistinctFrom(cast(C_INTEGER, DOUBLE), doubleLiteral(2.1))));
}

/**
 * Checks the domains derived for comparisons of {@code CAST(varchar AS date)} with date literals.
 * In every derived case the original predicate is expected to remain as the residual filter.
 */
@Test
public void testPredicateWithVarcharCastToDate()
{
    // Ranges for strings sorting before "1" or after "9"; they appear in most expected domains below
    Range beforeDigitOne = Range.lessThan(VARCHAR, utf8Slice("1"));
    Range afterDigitNine = Range.greaterThan(VARCHAR, utf8Slice("9"));
    // Date literal text written with surrounding whitespace, a sign and an unpadded month
    String paddedDate = " +2005-9-10 \t";

    // =
    TupleDomain<Symbol> equalPadded = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    beforeDigitOne,
                    Range.range(VARCHAR, utf8Slice("2005-09-10"), true, utf8Slice("2005-09-11"), false),
                    Range.range(VARCHAR, utf8Slice("2005-9-10"), true, utf8Slice("2005-9-11"), false),
                    afterDigitNine),
            false));
    assertPredicateDerives(
            equal(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", paddedDate)),
            equalPadded);

    // = with day ending with 9
    TupleDomain<Symbol> equalSingleDigitNine = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    beforeDigitOne,
                    Range.range(VARCHAR, utf8Slice("2005-09-09"), true, utf8Slice("2005-09-0:"), false),
                    Range.range(VARCHAR, utf8Slice("2005-09-9"), true, utf8Slice("2005-09-:"), false),
                    Range.range(VARCHAR, utf8Slice("2005-9-09"), true, utf8Slice("2005-9-0:"), false),
                    Range.range(VARCHAR, utf8Slice("2005-9-9"), true, utf8Slice("2005-9-:"), false),
                    afterDigitNine),
            false));
    assertPredicateDerives(
            equal(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-09")),
            equalSingleDigitNine);

    TupleDomain<Symbol> equalNineteenth = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    beforeDigitOne,
                    Range.range(VARCHAR, utf8Slice("2005-09-19"), true, utf8Slice("2005-09-1:"), false),
                    Range.range(VARCHAR, utf8Slice("2005-9-19"), true, utf8Slice("2005-9-1:"), false),
                    afterDigitNine),
            false));
    assertPredicateDerives(
            equal(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-19")),
            equalNineteenth);

    // !=
    TupleDomain<Symbol> notEqualPadded = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    Range.lessThan(VARCHAR, utf8Slice("2005-09-10")),
                    Range.range(VARCHAR, utf8Slice("2005-09-11"), true, utf8Slice("2005-9-10"), false),
                    Range.greaterThanOrEqual(VARCHAR, utf8Slice("2005-9-11"))),
            false));
    assertPredicateDerives(
            notEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", paddedDate)),
            notEqualPadded);

    // != with single-digit day
    assertUnsupportedPredicate(
            notEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", " +2005-9-2 \t")));

    // != with day ending with 9
    assertUnsupportedPredicate(
            notEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-09")));

    TupleDomain<Symbol> notEqualNineteenth = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    Range.lessThan(VARCHAR, utf8Slice("2005-09-19")),
                    Range.range(VARCHAR, utf8Slice("2005-09-1:"), true, utf8Slice("2005-9-19"), false),
                    Range.greaterThanOrEqual(VARCHAR, utf8Slice("2005-9-1:"))),
            false));
    assertPredicateDerives(
            notEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-19")),
            notEqualNineteenth);

    // <
    TupleDomain<Symbol> lessThanPadded = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    Range.lessThan(VARCHAR, utf8Slice("2006")),
                    afterDigitNine),
            false));
    assertPredicateDerives(
            lessThan(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", paddedDate)),
            lessThanPadded);

    // >
    TupleDomain<Symbol> greaterThanPadded = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    beforeDigitOne,
                    Range.greaterThan(VARCHAR, utf8Slice("2004"))),
            false));
    assertPredicateDerives(
            greaterThan(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", paddedDate)),
            greaterThanPadded);

    // BETWEEN
    TupleDomain<Symbol> betweenDomain = tupleDomain(C_VARCHAR, Domain.create(
            ValueSet.ofRanges(
                    beforeDigitOne,
                    Range.range(VARCHAR, utf8Slice("2000"), false, utf8Slice("2006"), false),
                    afterDigitNine),
            false));
    assertPredicateTranslates(
            between(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2001-01-31"), new GenericLiteral("DATE", "2005-09-10")),
            betweenDomain,
            and(
                    greaterThanOrEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2001-01-31")),
                    lessThanOrEqual(cast(C_VARCHAR, DATE), new GenericLiteral("DATE", "2005-09-10"))));
}

@Test
public void testFromUnprocessableInPredicate()
{
Expand Down Expand Up @@ -1911,6 +1995,11 @@ private void assertPredicateTranslates(Expression expression, TupleDomain<Symbol
assertPredicateTranslates(expression, tupleDomain, TRUE_LITERAL);
}

/**
 * Asserts that {@code expression} translates to {@code tupleDomain} while the whole
 * {@code expression} is still reported as the remaining expression.
 */
private void assertPredicateDerives(Expression expression, TupleDomain<Symbol> tupleDomain)
{
    // The predicate is expected to survive unchanged alongside the derived domain
    Expression expectedRemainder = expression;
    assertPredicateTranslates(expression, tupleDomain, expectedRemainder);
}

private void assertPredicateTranslates(Expression expression, TupleDomain<Symbol> tupleDomain, Expression remainingExpression)
{
ExtractionResult result = fromPredicate(expression);
Expand Down
Loading