Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,11 @@ private OptionalDouble doubleValueFromLiteral(Type type, Expression literal)
session,
new AllowAllAccessControl(),
ImmutableMap.of());

if (literalValue == null) {
return OptionalDouble.empty();
}

return toStatsRepresentation(type, literalValue);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
Expand All @@ -50,6 +51,8 @@
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand Down Expand Up @@ -80,11 +83,13 @@
import static io.trino.SystemSessionProperties.isComplexExpressionPushdown;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME;
Expand Down Expand Up @@ -272,6 +277,10 @@ protected Optional<Expression> translateCall(Call call)
}
}

if (IN_PREDICATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 2) {
return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1));
}

QualifiedName name = QualifiedName.of(call.getFunctionName().getName());
List<TypeSignature> argumentTypes = call.getArguments().stream()
.map(argument -> argument.getType().getTypeSignature())
Expand Down Expand Up @@ -344,15 +353,8 @@ private Optional<Expression> translateCast(Type type, ConnectorExpression expres

private Optional<Expression> translateLogicalExpression(LogicalExpression.Operator operator, List<ConnectorExpression> arguments)
{
ImmutableList.Builder<Expression> translatedArguments = ImmutableList.builderWithExpectedSize(arguments.size());
for (ConnectorExpression argument : arguments) {
Optional<Expression> translated = translate(argument);
if (translated.isEmpty()) {
return Optional.empty();
}
translatedArguments.add(translated.get());
}
return Optional.of(new LogicalExpression(operator, translatedArguments.build()));
Optional<List<Expression>> translatedArguments = translateExpressions(arguments);
return translatedArguments.map(expressions -> new LogicalExpression(operator, expressions));
}

private Optional<Expression> translateComparison(ComparisonExpression.Operator operator, ConnectorExpression left, ConnectorExpression right)
Expand Down Expand Up @@ -446,6 +448,46 @@ protected Optional<Expression> translateLike(ConnectorExpression value, Connecto

return Optional.empty();
}

protected Optional<Expression> translateInPredicate(ConnectorExpression value, ConnectorExpression values)
{
Optional<Expression> translatedValue = translate(value);
Optional<List<Expression>> translatedValues = extractExpressionsFromArrayCall(values);

if (translatedValue.isPresent() && translatedValues.isPresent()) {
return Optional.of(new InPredicate(translatedValue.get(), new InListExpression(translatedValues.get())));
}

return Optional.empty();
}

protected Optional<List<Expression>> extractExpressionsFromArrayCall(ConnectorExpression expression)
{
if (!(expression instanceof Call)) {
return Optional.empty();
}

Call call = (Call) expression;
if (!call.getFunctionName().equals(ARRAY_CONSTRUCTOR_FUNCTION_NAME)) {
return Optional.empty();
}

return translateExpressions(call.getArguments());
}

protected Optional<List<Expression>> translateExpressions(List<ConnectorExpression> expressions)
{
ImmutableList.Builder<Expression> translatedExpressions = ImmutableList.builderWithExpectedSize(expressions.size());
for (ConnectorExpression expression : expressions) {
Optional<Expression> translated = translate(expression);
if (translated.isEmpty()) {
return Optional.empty();
}
translatedExpressions.add(translated.get());
}

return Optional.of(translatedExpressions.build());
}
}

public static class SqlToConnectorExpressionTranslator
Expand Down Expand Up @@ -760,6 +802,31 @@ protected Optional<ConnectorExpression> visitSubscriptExpression(SubscriptExpres
return Optional.of(new FieldDereference(typeOf(node), translatedBase.get(), toIntExact(((LongLiteral) node.getIndex()).getValue() - 1)));
}

@Override
protected Optional<ConnectorExpression> visitInPredicate(InPredicate node, Void context)
{
InListExpression valueList = (InListExpression) node.getValueList();
Optional<ConnectorExpression> valueExpression = process(node.getValue());

if (valueExpression.isEmpty()) {
return Optional.empty();
}

ImmutableList.Builder<ConnectorExpression> values = ImmutableList.builderWithExpectedSize(valueList.getValues().size());
for (Expression value : valueList.getValues()) {
Optional<ConnectorExpression> processedValue = process(value);

if (processedValue.isEmpty()) {
return Optional.empty();
}

values.add(processedValue.get());
}

ConnectorExpression arrayExpression = new Call(new ArrayType(typeOf(node.getValueList())), ARRAY_CONSTRUCTOR_FUNCTION_NAME, values.build());
return Optional.of(new Call(typeOf(node), IN_PREDICATE_FUNCTION_NAME, List.of(valueExpression.get(), arrayExpression)));
}

@Override
protected Optional<ConnectorExpression> visitExpression(Expression node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.tree.ArithmeticBinaryExpression;
Expand All @@ -34,6 +35,8 @@
import io.trino.sql.tree.DoubleLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
Expand All @@ -59,6 +62,7 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME;
Expand Down Expand Up @@ -94,6 +98,8 @@ public class TestConnectorExpressionTranslator
private static final TypeAnalyzer TYPE_ANALYZER = createTestingTypeAnalyzer(PLANNER_CONTEXT);
private static final Type ROW_TYPE = rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5)));
private static final VarcharType VARCHAR_TYPE = createVarcharType(25);
private static final ArrayType VARCHAR_ARRAY_TYPE = new ArrayType(VARCHAR_TYPE);

private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(PLANNER_CONTEXT);

private static final Map<Symbol, Type> symbols = ImmutableMap.<Symbol, Type>builder()
Expand Down Expand Up @@ -418,6 +424,22 @@ public void testTranslateRegularExpression()
});
}

@Test
public void testTranslateIn()
{
String value = "value_1";
assertTranslationRoundTrips(
new InPredicate(
new SymbolReference("varchar_symbol_1"),
new InListExpression(List.of(new SymbolReference("varchar_symbol_1"), new StringLiteral(value)))),
new Call(
BOOLEAN,
StandardFunctions.IN_PREDICATE_FUNCTION_NAME,
List.of(
new Variable("varchar_symbol_1", VARCHAR_TYPE),
new Call(VARCHAR_ARRAY_TYPE, ARRAY_CONSTRUCTOR_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE), new Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), createVarcharType(value.length())))))));
}

private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression)
{
assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,16 @@ private StandardFunctions() {}
public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate");

public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern");

/**
* {@code $in(value, array)} returns {@code true} when value is equal to an element of the array,
* otherwise returns {@code NULL} when comparing value to an element of the array returns an
* indeterminate result, otherwise returns {@code false}
*/
public static final FunctionName IN_PREDICATE_FUNCTION_NAME = new FunctionName("$in");

/**
* $array creates instance of {@link Array Type}
*/
public static final FunctionName ARRAY_CONSTRUCTOR_FUNCTION_NAME = new FunctionName("$array");
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ public static Pattern<Call> call()
return Property.property("argumentCount", call -> call.getArguments().size());
}

public static Property<Call, ?, List<ConnectorExpression>> arguments()
{
return Property.property("arguments", Call::getArguments);
}

public static Property<Call, ?, ConnectorExpression> argument(int argument)
{
checkArgument(0 <= argument, "Invalid argument index: %s", argument);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.spi.type.Type;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

Expand All @@ -39,9 +41,14 @@ public static JdbcConnectorExpressionRewriterBuilder newBuilder()
private JdbcConnectorExpressionRewriterBuilder() {}

public JdbcConnectorExpressionRewriterBuilder addStandardRules(Function<String, String> identifierQuote)
{
return addStandardRules(identifierQuote, type -> Optional.empty());
}

public JdbcConnectorExpressionRewriterBuilder addStandardRules(Function<String, String> identifierQuote, Function<Type, Optional<String>> typeMapping)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is simple for PostgreSQL and varchar case, but generally this is tricky, since remote database and Trino have different type systems.

In case of

where x IN (a, b, c, NULL)

the NULL should be eliminated on the engine side (generic engine rule).

Do we have other cases where this is actually useful?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x IN (NULL) is always false, right? Even for NULL IN (NULL) so eliminating NULLs is safe?

also x IN () (empty list) is also false

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x IN (NULL) is always false, right?

result is NULL

also x IN () (empty list) is also false

This is not allowed in SQL, but it it was allowed, it would be false.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the NULL should be eliminated on the engine side (generic engine rule).

BTW such rule, if written, should go in separate PR.

in this PR please keep the tests exercising IN pushdown with NULLs (asserting that these doesn't get pushed down).

{
add(new RewriteVariable(identifierQuote));
add(new RewriteVarcharConstant());
add(new RewriteVarcharConstant(typeMapping));
add(new RewriteExactNumericConstant());
add(new RewriteAnd());
add(new RewriteOr());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc.expression;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.type.Type;

import java.util.Optional;
import java.util.function.Function;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.spi.expression.StandardFunctions.CAST_FUNCTION_NAME;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class RewriteCast
implements ConnectorExpressionRule<Call, String>
{
private static final Capture<ConnectorExpression> ARGUMENT = newCapture();

private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(CAST_FUNCTION_NAME))
.with(argument(0).capturedAs(ARGUMENT));
Comment on lines +41 to +42
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fmt


private final Function<Type, Optional<String>> typeMapping;

public RewriteCast(Function<Type, Optional<String>> typeMapping)
{
this.typeMapping = requireNonNull(typeMapping, "typeMapping is null");
}

@Override
public Pattern<Call> getPattern()
{
return PATTERN;
}

@Override
public Optional<String> rewrite(Call expression, Captures captures, RewriteContext<String> context)
{
ConnectorExpression argument = captures.get(ARGUMENT);
Optional<String> typeCast = typeMapping.apply(expression.getType());

if (typeCast.isEmpty()) {
return Optional.empty();
}

Optional<String> translatedArgument = context.defaultRewrite(argument);
if (translatedArgument.isEmpty()) {
return Optional.empty();
}

return Optional.of(format("CAST(%s AS %s)", translatedArgument.get(), typeCast.get()));
}
}
Loading