Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import io.trino.sql.DynamicFilters;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BinaryLiteral;
import io.trino.sql.tree.BooleanLiteral;
Expand Down Expand Up @@ -67,18 +69,25 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.trino.SystemSessionProperties.isComplexExpressionPushdown;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.MODULUS_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
Expand Down Expand Up @@ -124,6 +133,24 @@ static FunctionName functionNameForComparisonOperator(ComparisonExpression.Opera
throw new UnsupportedOperationException("Unsupported operator: " + operator);
}

@VisibleForTesting
static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpression.Operator operator)
{
switch (operator) {
case ADD:
return ADD_FUNCTION_NAME;
case SUBTRACT:
return SUBTRACT_FUNCTION_NAME;
case MULTIPLY:
return MULTIPLY_FUNCTION_NAME;
case DIVIDE:
return DIVIDE_FUNCTION_NAME;
case MODULUS:
return MODULUS_FUNCTION_NAME;
}
throw new UnsupportedOperationException("Unsupported operator: " + operator);
}

private static class ConnectorToSqlExpressionTranslator
{
private final Session session;
Expand Down Expand Up @@ -176,13 +203,27 @@ protected Optional<Expression> translateCall(Call call)
return translateLogicalExpression(LogicalExpression.Operator.OR, call.getArguments());
}

// comparisons
if (call.getArguments().size() == 2) {
Optional<ComparisonExpression.Operator> operator = comparisonOperatorForFunctionName(call.getFunctionName());
if (operator.isPresent()) {
return translateComparison(operator.get(), call.getArguments().get(0), call.getArguments().get(1));
}
}

// arithmetic binary
if (call.getArguments().size() == 2) {
Optional<ArithmeticBinaryExpression.Operator> operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName());
if (operator.isPresent()) {
return translateArithmeticBinary(operator.get(), call.getArguments().get(0), call.getArguments().get(1));
}
}

// arithmetic unary
if (NEGATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) {
return translate(getOnlyElement(call.getArguments())).map(argument -> new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, argument));
}

if (LIKE_PATTERN_FUNCTION_NAME.equals(call.getFunctionName())) {
switch (call.getArguments().size()) {
case 2:
Expand Down Expand Up @@ -256,6 +297,33 @@ private Optional<ComparisonExpression.Operator> comparisonOperatorForFunctionNam
return Optional.empty();
}

private Optional<Expression> translateArithmeticBinary(ArithmeticBinaryExpression.Operator operator, ConnectorExpression left, ConnectorExpression right)
{
return translate(left).flatMap(leftTranslated ->
translate(right).map(rightTranslated ->
new ArithmeticBinaryExpression(operator, leftTranslated, rightTranslated)));
}

private Optional<ArithmeticBinaryExpression.Operator> arithmeticBinaryOperatorForFunctionName(FunctionName functionName)
{
if (ADD_FUNCTION_NAME.equals(functionName)) {
return Optional.of(ArithmeticBinaryExpression.Operator.ADD);
}
if (SUBTRACT_FUNCTION_NAME.equals(functionName)) {
return Optional.of(ArithmeticBinaryExpression.Operator.SUBTRACT);
}
if (MULTIPLY_FUNCTION_NAME.equals(functionName)) {
return Optional.of(ArithmeticBinaryExpression.Operator.MULTIPLY);
}
if (DIVIDE_FUNCTION_NAME.equals(functionName)) {
return Optional.of(ArithmeticBinaryExpression.Operator.DIVIDE);
}
if (MODULUS_FUNCTION_NAME.equals(functionName)) {
return Optional.of(ArithmeticBinaryExpression.Operator.MODULUS);
}
return Optional.empty();
}

protected Optional<Expression> translateLike(ConnectorExpression value, ConnectorExpression pattern, Optional<ConnectorExpression> escape)
{
Optional<Expression> translatedValue = translate(value);
Expand Down Expand Up @@ -390,6 +458,31 @@ protected Optional<ConnectorExpression> visitComparisonExpression(ComparisonExpr
new Call(typeOf(node), functionNameForComparisonOperator(node.getOperator()), ImmutableList.of(left, right))));
}

@Override
protected Optional<ConnectorExpression> visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
{
if (!isComplexExpressionPushdown(session)) {
return Optional.empty();
}
return process(node.getLeft()).flatMap(left -> process(node.getRight()).map(right ->
Comment thread
findepi marked this conversation as resolved.
Outdated
new Call(typeOf(node), functionNameForArithmeticBinaryOperator(node.getOperator()), ImmutableList.of(left, right))));
}

@Override
protected Optional<ConnectorExpression> visitArithmeticUnary(ArithmeticUnaryExpression node, Void context)
{
if (!isComplexExpressionPushdown(session)) {
return Optional.empty();
}
switch (node.getSign()) {
case PLUS:
return process(node.getValue());
case MINUS:
return process(node.getValue()).map(value -> new Call(typeOf(node), NEGATE_FUNCTION_NAME, ImmutableList.of(value)));
}
throw new UnsupportedOperationException("Unsupported sign: " + node.getSign());
}

@Override
protected Optional<ConnectorExpression> visitCast(Cast node, Void context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.LikePredicate;
Expand All @@ -47,6 +49,7 @@
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DecimalType.createDecimalType;
Expand Down Expand Up @@ -177,6 +180,41 @@ public static Object[][] testTranslateComparisonExpressionDataProvider()
.collect(toDataProvider());
}

@Test(dataProvider = "testTranslateArithmeticBinaryDataProvider")
public void testTranslateArithmeticBinary(ArithmeticBinaryExpression.Operator operator)
Comment thread
hashhar marked this conversation as resolved.
Outdated
{
assertTranslationRoundTrips(
new ArithmeticBinaryExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
new Call(
DOUBLE,
ConnectorExpressionTranslator.functionNameForArithmeticBinaryOperator(operator),
List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE))));
}

@DataProvider
public static Object[][] testTranslateArithmeticBinaryDataProvider()
{
return Stream.of(ArithmeticBinaryExpression.Operator.values())
.collect(toDataProvider());
}

@Test
public void testTranslateArithmeticUnaryMinus()
{
assertTranslationRoundTrips(
new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("double_symbol_1")),
new Call(DOUBLE, NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE))));
}

@Test
public void testTranslateArithmeticUnaryPlus()
{
assertTranslationToConnectorExpression(
TEST_SESSION,
new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.PLUS, new SymbolReference("double_symbol_1")),
new Variable("double_symbol_1", DOUBLE));
}

@Test
public void testTranslateLike()
{
Expand Down Expand Up @@ -234,6 +272,11 @@ private void assertTranslationRoundTrips(Session session, Expression expression,
assertTranslationFromConnectorExpression(session, connectorExpression, expression);
}

private void assertTranslationToConnectorExpression(Session session, Expression expression, ConnectorExpression connectorExpression)
{
assertTranslationToConnectorExpression(session, expression, Optional.of(connectorExpression));
}

private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional<ConnectorExpression> connectorExpression)
{
Optional<ConnectorExpression> translation = translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.type.Type;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.AtTimeZone;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.LongLiteral;
Expand All @@ -39,6 +41,8 @@
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RowType.field;
import static io.trino.spi.type.RowType.rowType;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.planner.ConnectorExpressionTranslator.translate;
import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations;
Expand All @@ -58,7 +62,11 @@ public class TestPartialTranslator
.put(new Symbol("double_symbol_1"), DOUBLE)
.put(new Symbol("double_symbol_2"), DOUBLE)
.put(new Symbol("bigint_symbol_1"), BIGINT)
.put(new Symbol("row_symbol_1"), rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5))))
.put(new Symbol("timestamp3_symbol_1"), TIMESTAMP_MILLIS)
.put(new Symbol("row_symbol_1"), rowType(
field("int_symbol_1", INTEGER),
field("varchar_symbol_1", createVarcharType(5)),
field("timestamptz3_field_1", TIMESTAMP_TZ_MILLIS)))
.buildOrThrow());

@Test
Expand All @@ -67,15 +75,21 @@ public void testPartialTranslator()
Expression rowSymbolReference = new SymbolReference("row_symbol_1");
Expression dereferenceExpression1 = new SubscriptExpression(rowSymbolReference, new LongLiteral("1"));
Expression dereferenceExpression2 = new SubscriptExpression(rowSymbolReference, new LongLiteral("2"));
Expression dereferenceExpression3 = new SubscriptExpression(rowSymbolReference, new LongLiteral("3"));
Expression stringLiteral = new StringLiteral("abcd");
Expression symbolReference1 = new SymbolReference("double_symbol_1");
SymbolReference timestamp3SymbolReference = new SymbolReference("timestamp3_symbol_1");

assertFullTranslation(symbolReference1);
assertFullTranslation(dereferenceExpression1);
assertFullTranslation(stringLiteral);
assertFullTranslation(new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1));

Expression binaryExpression = new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1);
assertPartialTranslation(binaryExpression, ImmutableList.of(symbolReference1, dereferenceExpression1));
assertPartialTranslation(
new CoalesceExpression(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i replaced + with coalesce as an example of an expression that cannot be translated.
This is OK for now, but we need a better testing strategy, see eg #11535
cc @ebyhr

new AtTimeZone(timestamp3SymbolReference, stringLiteral),
dereferenceExpression3),
List.of(timestamp3SymbolReference, stringLiteral, dereferenceExpression3));

List<Expression> functionArguments = ImmutableList.of(stringLiteral, dereferenceExpression2);
Expression functionCallExpression = new FunctionCall(QualifiedName.of("concat"), functionArguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@
import io.trino.spi.type.Type;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.SymbolReference;
Expand All @@ -56,8 +59,10 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.spi.predicate.Domain.singleValue;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout;
Expand Down Expand Up @@ -217,13 +222,16 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider()
.functionCallBuilder(QualifiedName.of("rand"))
.build(),
new GenericLiteral("BIGINT", "42")),
new ComparisonExpression(
EQUAL,
new ArithmeticBinaryExpression(
MODULUS,
new SymbolReference("nationkey"),
new GenericLiteral("BIGINT", "17")),
new GenericLiteral("BIGINT", "44")),
// non-translatable to connector expression
new CoalesceExpression(
new Cast(new NullLiteral(), toSqlType(BOOLEAN)),
new ComparisonExpression(
EQUAL,
new ArithmeticBinaryExpression(
MODULUS,
new SymbolReference("nationkey"),
new GenericLiteral("BIGINT", "17")),
new GenericLiteral("BIGINT", "44"))),
LogicalExpression.or(
new ComparisonExpression(
EQUAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,35 @@ private StandardFunctions() {}
public static final FunctionName GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$greater_than_or_equal");
public static final FunctionName IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME = new FunctionName("$is_distinct_from");

/**
* Arithmetic addition.
*/
public static final FunctionName ADD_FUNCTION_NAME = new FunctionName("$add");

/**
* Arithmetic subtraction.
*/
public static final FunctionName SUBTRACT_FUNCTION_NAME = new FunctionName("$subtract");

/**
* Arithmetic multiplication.
*/
public static final FunctionName MULTIPLY_FUNCTION_NAME = new FunctionName("$multiply");

/**
* Arithmetic division.
*/
public static final FunctionName DIVIDE_FUNCTION_NAME = new FunctionName("$divide");

/**
* Arithmetic modulus.
*/
public static final FunctionName MODULUS_FUNCTION_NAME = new FunctionName("$modulus");

/**
* Arithmetic unary minus.
*/
public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate");

public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern");
}