-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Improve comparison predicate pushdown when varchar column cast to date #13567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,6 +73,7 @@ | |
| import javax.annotation.Nullable; | ||
|
|
||
| import java.lang.invoke.MethodHandle; | ||
| import java.time.LocalDate; | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
|
|
@@ -81,19 +82,22 @@ | |
|
|
||
| import static com.google.common.base.Preconditions.checkArgument; | ||
| import static com.google.common.base.Preconditions.checkState; | ||
| import static com.google.common.base.Verify.verify; | ||
| import static com.google.common.collect.ImmutableList.toImmutableList; | ||
| import static com.google.common.collect.Iterables.getOnlyElement; | ||
| import static com.google.common.collect.Iterators.peekingIterator; | ||
| import static io.airlift.slice.SliceUtf8.countCodePoints; | ||
| import static io.airlift.slice.SliceUtf8.getCodePointAt; | ||
| import static io.airlift.slice.SliceUtf8.lengthOfCodePoint; | ||
| import static io.airlift.slice.SliceUtf8.setCodePointAt; | ||
| import static io.airlift.slice.Slices.utf8Slice; | ||
| import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; | ||
| import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; | ||
| import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; | ||
| import static io.trino.spi.function.InvocationConvention.simpleConvention; | ||
| import static io.trino.spi.function.OperatorType.SATURATED_FLOOR_CAST; | ||
| import static io.trino.spi.type.BooleanType.BOOLEAN; | ||
| import static io.trino.spi.type.DateType.DATE; | ||
| import static io.trino.spi.type.TypeUtils.isFloatingPointNaN; | ||
| import static io.trino.sql.ExpressionUtils.and; | ||
| import static io.trino.sql.ExpressionUtils.combineConjuncts; | ||
|
|
@@ -104,6 +108,7 @@ | |
| import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; | ||
| import static io.trino.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; | ||
|
|
@@ -502,6 +507,19 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, | |
| } | ||
| if (symbolExpression instanceof Cast) { | ||
| Cast castExpression = (Cast) symbolExpression; | ||
| // type of expression which is then cast to type of value | ||
| Type castSourceType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression.getExpression())), "No type for Cast source expression"); | ||
| Type castTargetType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression)), "No type for Cast target expression"); | ||
| if (castSourceType instanceof VarcharType && castTargetType == DATE && !castExpression.isSafe()) { | ||
| Optional<ExtractionResult> result = createVarcharCastToDateComparisonExtractionResult( | ||
| node, | ||
| (VarcharType) castSourceType, | ||
| normalized.getValue(), | ||
| complement); | ||
| if (result.isPresent()) { | ||
| return result.get(); | ||
| } | ||
| } | ||
| if (!isImplicitCoercion(expressionTypes, castExpression)) { | ||
| // | ||
| // we cannot use non-coercion cast to literal_type on symbol side to build tuple domain | ||
|
|
@@ -524,9 +542,6 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, | |
| return super.visitComparisonExpression(node, complement); | ||
| } | ||
|
|
||
| // type of expression which is then cast to type of value | ||
| Type castSourceType = requireNonNull(expressionTypes.get(NodeRef.of(castExpression.getExpression())), "No type for Cast source expression"); | ||
|
|
||
| // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side | ||
| Optional<Expression> coercedExpression = coerceComparisonWithRounding( | ||
| castSourceType, castExpression.getExpression(), normalized.getValue(), normalized.getComparisonOperator()); | ||
|
|
@@ -588,6 +603,115 @@ private Map<NodeRef<Expression>, Type> analyzeExpression(Expression expression) | |
| return typeAnalyzer.getTypes(session, types, expression); | ||
| } | ||
|
|
||
| private Optional<ExtractionResult> createVarcharCastToDateComparisonExtractionResult( | ||
| ComparisonExpression node, | ||
| VarcharType sourceType, | ||
| NullableValue value, | ||
| boolean complement) | ||
| { | ||
| Cast castExpression = (Cast) node.getLeft(); | ||
| Expression sourceExpression = castExpression.getExpression(); | ||
| ComparisonExpression.Operator comparisonOperator = node.getOperator(); | ||
| requireNonNull(value, "value is null"); | ||
|
|
||
| if (complement || value.isNull()) { | ||
| return Optional.empty(); | ||
| } | ||
| if (!(sourceExpression instanceof SymbolReference)) { | ||
| // Calculation is not useful | ||
| return Optional.empty(); | ||
| } | ||
| Symbol sourceSymbol = Symbol.from(sourceExpression); | ||
|
|
||
| if (!sourceType.isUnbounded() && sourceType.getBoundedLength() < 10) { | ||
| // too short | ||
| return Optional.empty(); | ||
| } | ||
|
|
||
| LocalDate date = LocalDate.ofEpochDay(((long) value.getValue())); | ||
| if (date.getYear() < 1001 || date.getYear() > 9998) { | ||
| // Edge cases. 1-year margin so that we can go to next/prev year for < or > comparisons | ||
| return Optional.empty(); | ||
| } | ||
|
|
||
| // superset of possible values, for the "normal case" | ||
| ValueSet valueSet; | ||
| boolean nullAllowed = false; | ||
|
|
||
| switch (comparisonOperator) { | ||
| case EQUAL: | ||
| valueSet = dateStringRanges(date, sourceType); | ||
| break; | ||
| case NOT_EQUAL: | ||
| case IS_DISTINCT_FROM: | ||
| if (date.getDayOfMonth() < 10) { | ||
| // TODO: possible to handle but cumbersome | ||
| return Optional.empty(); | ||
| } | ||
| valueSet = ValueSet.all(sourceType).subtract(dateStringRanges(date, sourceType)); | ||
| nullAllowed = (comparisonOperator == IS_DISTINCT_FROM); | ||
| break; | ||
| case LESS_THAN: | ||
| case LESS_THAN_OR_EQUAL: | ||
| valueSet = ValueSet.ofRanges(Range.lessThan(sourceType, utf8Slice(Integer.toString(date.getYear() + 1)))); | ||
| break; | ||
| case GREATER_THAN: | ||
| case GREATER_THAN_OR_EQUAL: | ||
| valueSet = ValueSet.ofRanges(Range.greaterThan(sourceType, utf8Slice(Integer.toString(date.getYear() - 1)))); | ||
| break; | ||
| default: | ||
| return Optional.empty(); | ||
| } | ||
|
|
||
| // Date representations starting with whitespace, sign or leading zeroes. | ||
| valueSet = valueSet.union(ValueSet.ofRanges( | ||
| Range.lessThan(sourceType, utf8Slice("1")), | ||
| Range.greaterThan(sourceType, utf8Slice("9")))); | ||
|
|
||
| return Optional.of(new ExtractionResult( | ||
| TupleDomain.withColumnDomains(ImmutableMap.of(sourceSymbol, Domain.create(valueSet, nullAllowed))), | ||
| node)); | ||
| } | ||
|
|
||
| /** | ||
| * @return Date representations of the form 2005-09-09, 2005-09-9, 2005-9-09 and 2005-9-9 expanded to ranges: | ||
| * {@code [2005-09-09, 2005-09-0:), [2005-09-9, 2005-09-:), [2005-9-09, 2005-9-0:), [2005-9-9, 2005-9-:)} | ||
| * (the {@code :} character is the next one after {@code 9}). | ||
| */ | ||
| private static SortedRangeSet dateStringRanges(LocalDate date, VarcharType domainType) | ||
| { | ||
| checkArgument(date.getYear() >= 1000 && date.getYear() <= 9999, "Unsupported date: %s", date); | ||
|
|
||
| int month = date.getMonthValue(); | ||
| int day = date.getDayOfMonth(); | ||
| boolean isMonthSingleDigit = date.getMonthValue() < 10; | ||
| boolean isDaySingleDigit = date.getDayOfMonth() < 10; | ||
|
|
||
| // A specific date value like 2005-09-10 can be a result of a CAST for number of various forms, | ||
| // as the value can have optional sign, leading zeros for the year, and surrounding whitespace, | ||
| // E.g. ' +002005-9-9 '. | ||
|
|
||
| List<Range> valueRanges = new ArrayList<>(4); | ||
| for (boolean useSingleDigitMonth : List.of(true, false)) { | ||
| for (boolean useSingleDigitDay : List.of(true, false)) { | ||
| if (useSingleDigitMonth && !isMonthSingleDigit) { | ||
| continue; | ||
| } | ||
| if (useSingleDigitDay && !isDaySingleDigit) { | ||
| continue; | ||
| } | ||
| String dateString = date.getYear() + | ||
| ((!useSingleDigitMonth && isMonthSingleDigit) ? "-0" : "-") + month + | ||
| ((!useSingleDigitDay && isDaySingleDigit) ? "-0" : "-") + day; | ||
| String nextStringPrefix = dateString.substring(0, dateString.length() - 1) + (char) (dateString.charAt(dateString.length() - 1) + 1); // cannot overflow | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sanity question. Char SQL type ordering uses the same order as Java char ordering? ie. if you add one to a char in java you also get the next char when it comes to SQL char sorting
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general -- no (i think it breaks for non-BMP) |
||
| verify(dateString.length() <= domainType.getLength().orElse(Integer.MAX_VALUE), "dateString length exceeds type bounds"); | ||
| verify(dateString.length() == nextStringPrefix.length(), "Next string length mismatch"); | ||
| valueRanges.add(Range.range(domainType, utf8Slice(dateString), true, utf8Slice(nextStringPrefix), false)); | ||
| } | ||
| } | ||
| return (SortedRangeSet) ValueSet.ofRanges(valueRanges); | ||
| } | ||
|
|
||
| private static Optional<ExtractionResult> createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) | ||
| { | ||
| if (value == null) { | ||
|
|
@@ -990,7 +1114,7 @@ private Optional<ExtractionResult> tryVisitLikePredicate(LikePredicate node, Boo | |
| VarcharType varcharType = (VarcharType) type; | ||
|
|
||
| Symbol symbol = Symbol.from(node.getValue()); | ||
| Slice pattern = Slices.utf8Slice(((StringLiteral) node.getPattern()).getValue()); | ||
| Slice pattern = utf8Slice(((StringLiteral) node.getPattern()).getValue()); | ||
| Optional<Slice> escape = node.getEscape() | ||
| .map(StringLiteral.class::cast) | ||
| .map(StringLiteral::getValue) | ||
|
|
@@ -1064,7 +1188,7 @@ private Optional<ExtractionResult> tryVisitStartsWithFunction(FunctionCall node, | |
| } | ||
|
|
||
| Symbol symbol = Symbol.from(target); | ||
| Slice constantPrefix = Slices.utf8Slice(((StringLiteral) prefix).getValue()); | ||
| Slice constantPrefix = utf8Slice(((StringLiteral) prefix).getValue()); | ||
|
|
||
| return createRangeDomain(type, constantPrefix).map(domain -> new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(symbol, domain)), node)); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| /* | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package io.trino.testing.assertions; | ||
|
|
||
| import java.util.function.Consumer; | ||
| import java.util.function.Supplier; | ||
|
|
||
| import static java.util.Objects.requireNonNull; | ||
|
|
||
/**
 * Assertion helper for code paths that may legitimately either return a value or throw.
 */
public final class TestUtil
{
    private TestUtil() {}

    /**
     * Invokes {@code callback} and routes the outcome to exactly one of the two verifiers.
     *
     * @param callback the action under test
     * @param verifyResults receives the value returned by {@code callback} when it completes normally
     * @param verifyFailure receives whatever {@code callback} threw, including {@link Error}s
     * @throws NullPointerException if any argument is null
     */
    public static <T> void verifyResultOrFailure(Supplier<T> callback, Consumer<T> verifyResults, Consumer<Throwable> verifyFailure)
    {
        requireNonNull(callback, "callback is null");
        requireNonNull(verifyResults, "verifyResults is null");
        requireNonNull(verifyFailure, "verifyFailure is null");

        T outcome = null;
        Throwable failure = null;
        try {
            outcome = callback.get();
        }
        catch (Throwable thrown) {
            failure = thrown;
        }

        // Only the callback is guarded: anything thrown by a verifier propagates to the caller.
        if (failure != null) {
            verifyFailure.accept(failure);
        }
        else {
            verifyResults.accept(outcome);
        }
    }
}
Uh oh!
There was an error while loading. Please reload this page.