From 4ae14d63d25b57f0afd231d487149265035873a1 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 17 Mar 2022 13:35:28 +0100 Subject: [PATCH 1/2] Add another partial translation test case --- .../sql/planner/TestPartialTranslator.java | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 27106281e000..13f5a77643a5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -19,6 +19,8 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.Type; import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.AtTimeZone; +import io.trino.sql.tree.CoalesceExpression; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.LongLiteral; @@ -39,6 +41,8 @@ import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.RowType.field; import static io.trino.spi.type.RowType.rowType; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.PartialTranslator.extractPartialTranslations; @@ -58,7 +62,11 @@ public class TestPartialTranslator .put(new Symbol("double_symbol_1"), DOUBLE) .put(new Symbol("double_symbol_2"), DOUBLE) .put(new Symbol("bigint_symbol_1"), BIGINT) - .put(new Symbol("row_symbol_1"), rowType(field("int_symbol_1", INTEGER), field("varchar_symbol_1", createVarcharType(5)))) + .put(new Symbol("timestamp3_symbol_1"), TIMESTAMP_MILLIS) + .put(new Symbol("row_symbol_1"), rowType( + field("int_symbol_1", INTEGER), + field("varchar_symbol_1", createVarcharType(5)), + field("timestamptz3_field_1", TIMESTAMP_TZ_MILLIS))) .buildOrThrow()); @Test @@ -67,15 +75,24 @@ public void testPartialTranslator() Expression rowSymbolReference = new SymbolReference("row_symbol_1"); Expression dereferenceExpression1 = new SubscriptExpression(rowSymbolReference, new LongLiteral("1")); Expression dereferenceExpression2 = new SubscriptExpression(rowSymbolReference, new LongLiteral("2")); + Expression dereferenceExpression3 = new SubscriptExpression(rowSymbolReference, new LongLiteral("3")); Expression stringLiteral = new StringLiteral("abcd"); Expression symbolReference1 = new SymbolReference("double_symbol_1"); + SymbolReference timestamp3SymbolReference = new SymbolReference("timestamp3_symbol_1"); assertFullTranslation(symbolReference1); assertFullTranslation(dereferenceExpression1); assertFullTranslation(stringLiteral); - Expression binaryExpression = new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1); - assertPartialTranslation(binaryExpression, ImmutableList.of(symbolReference1, dereferenceExpression1)); + assertPartialTranslation( + new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1), + List.of(symbolReference1, dereferenceExpression1)); + + assertPartialTranslation( + new CoalesceExpression( + new AtTimeZone(timestamp3SymbolReference, stringLiteral), + dereferenceExpression3), + List.of(timestamp3SymbolReference, stringLiteral, dereferenceExpression3)); List functionArguments = ImmutableList.of(stringLiteral, dereferenceExpression2); Expression functionCallExpression = new FunctionCall(QualifiedName.of("concat"), functionArguments); From 8d04de9b9ce0ae78b00a5a9f6dddbea29319b4d7 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 17 Mar 2022 15:15:30 +0100 Subject: [PATCH 2/2] Translate arithmetic to connector expression --- .../ConnectorExpressionTranslator.java | 93 +++++++++++++++++++ .../TestConnectorExpressionTranslator.java | 43 +++++++++ .../sql/planner/TestPartialTranslator.java | 5 +- .../rule/TestPushPredicateIntoTableScan.java | 22 +++-- .../spi/expression/StandardFunctions.java | 30 ++++++ 5 files changed, 182 insertions(+), 11 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 65bd10992e37..b9fee5704757 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -36,6 +36,8 @@ import io.trino.sql.DynamicFilters; import io.trino.sql.PlannerContext; import io.trino.sql.analyzer.TypeSignatureProvider; +import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.BinaryLiteral; import io.trino.sql.tree.BooleanLiteral; @@ -67,9 +69,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.SliceUtf8.countCodePoints; import static io.trino.SystemSessionProperties.isComplexExpressionPushdown; +import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.DIVIDE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; @@ -77,8 +82,12 @@ import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.MODULUS_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.MULTIPLY_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.OR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral; import static io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression; @@ -124,6 +133,24 @@ static FunctionName functionNameForComparisonOperator(ComparisonExpression.Opera throw new UnsupportedOperationException("Unsupported operator: " + operator); } + @VisibleForTesting + static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpression.Operator operator) + { + switch (operator) { + case ADD: + return ADD_FUNCTION_NAME; + case SUBTRACT: + return SUBTRACT_FUNCTION_NAME; + case MULTIPLY: + return MULTIPLY_FUNCTION_NAME; + case DIVIDE: + return DIVIDE_FUNCTION_NAME; + case MODULUS: + return MODULUS_FUNCTION_NAME; + } + throw new UnsupportedOperationException("Unsupported operator: " + operator); + } + private static class ConnectorToSqlExpressionTranslator { private final Session session; @@ -176,6 +203,7 @@ protected Optional translateCall(Call call) return translateLogicalExpression(LogicalExpression.Operator.OR, call.getArguments()); } + // comparisons if (call.getArguments().size() == 2) { Optional operator = comparisonOperatorForFunctionName(call.getFunctionName()); if (operator.isPresent()) { @@ -183,6 +211,19 @@ protected Optional translateCall(Call call) } } + // arithmetic binary + if (call.getArguments().size() == 2) { + Optional operator = arithmeticBinaryOperatorForFunctionName(call.getFunctionName()); + if (operator.isPresent()) { + return translateArithmeticBinary(operator.get(), call.getArguments().get(0), call.getArguments().get(1)); + } + } + + // arithmetic unary + if (NEGATE_FUNCTION_NAME.equals(call.getFunctionName()) && call.getArguments().size() == 1) { + return translate(getOnlyElement(call.getArguments())).map(argument -> new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, argument)); + } + if (LIKE_PATTERN_FUNCTION_NAME.equals(call.getFunctionName())) { switch (call.getArguments().size()) { case 2: @@ -256,6 +297,33 @@ private Optional comparisonOperatorForFunctionNam return Optional.empty(); } + private Optional translateArithmeticBinary(ArithmeticBinaryExpression.Operator operator, ConnectorExpression left, ConnectorExpression right) + { + return translate(left).flatMap(leftTranslated -> + translate(right).map(rightTranslated -> + new ArithmeticBinaryExpression(operator, leftTranslated, rightTranslated))); + } + + private Optional arithmeticBinaryOperatorForFunctionName(FunctionName functionName) + { + if (ADD_FUNCTION_NAME.equals(functionName)) { + return Optional.of(ArithmeticBinaryExpression.Operator.ADD); + } + if (SUBTRACT_FUNCTION_NAME.equals(functionName)) { + return Optional.of(ArithmeticBinaryExpression.Operator.SUBTRACT); + } + if (MULTIPLY_FUNCTION_NAME.equals(functionName)) { + return Optional.of(ArithmeticBinaryExpression.Operator.MULTIPLY); + } + if (DIVIDE_FUNCTION_NAME.equals(functionName)) { + return Optional.of(ArithmeticBinaryExpression.Operator.DIVIDE); + } + if (MODULUS_FUNCTION_NAME.equals(functionName)) { + return Optional.of(ArithmeticBinaryExpression.Operator.MODULUS); + } + return Optional.empty(); + } + protected Optional translateLike(ConnectorExpression value, ConnectorExpression pattern, Optional escape) { Optional translatedValue = translate(value); @@ -390,6 +458,31 @@ protected Optional visitComparisonExpression(ComparisonExpr new Call(typeOf(node), functionNameForComparisonOperator(node.getOperator()), ImmutableList.of(left, right)))); } + @Override + protected Optional visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) + { + if (!isComplexExpressionPushdown(session)) { + return Optional.empty(); + } + return process(node.getLeft()).flatMap(left -> process(node.getRight()).map(right -> + new Call(typeOf(node), functionNameForArithmeticBinaryOperator(node.getOperator()), ImmutableList.of(left, right)))); + } + + @Override + protected Optional visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) + { + if (!isComplexExpressionPushdown(session)) { + return Optional.empty(); + } + switch (node.getSign()) { + case PLUS: + return process(node.getValue()); + case MINUS: + return process(node.getValue()).map(value -> new Call(typeOf(node), NEGATE_FUNCTION_NAME, ImmutableList.of(value))); + } + throw new UnsupportedOperationException("Unsupported sign: " + node.getSign()); + } + @Override protected Optional visitCast(Cast node, Void context) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index a3dee601f15d..ff562b40fd3f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -25,6 +25,8 @@ import io.trino.spi.expression.StandardFunctions; import io.trino.spi.expression.Variable; import io.trino.spi.type.Type; +import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.Expression; import io.trino.sql.tree.LikePredicate; @@ -47,6 +49,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; @@ -177,6 +180,41 @@ public static Object[][] testTranslateComparisonExpressionDataProvider() .collect(toDataProvider()); } + @Test(dataProvider = "testTranslateArithmeticBinaryDataProvider") + public void testTranslateArithmeticBinary(ArithmeticBinaryExpression.Operator operator) + { + assertTranslationRoundTrips( + new ArithmeticBinaryExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")), + new Call( + DOUBLE, + ConnectorExpressionTranslator.functionNameForArithmeticBinaryOperator(operator), + List.of(new Variable("double_symbol_1", DOUBLE), new Variable("double_symbol_2", DOUBLE)))); + } + + @DataProvider + public static Object[][] testTranslateArithmeticBinaryDataProvider() + { + return Stream.of(ArithmeticBinaryExpression.Operator.values()) + .collect(toDataProvider()); + } + + @Test + public void testTranslateArithmeticUnaryMinus() + { + assertTranslationRoundTrips( + new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("double_symbol_1")), + new Call(DOUBLE, NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DOUBLE)))); + } + + @Test + public void testTranslateArithmeticUnaryPlus() + { + assertTranslationToConnectorExpression( + TEST_SESSION, + new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.PLUS, new SymbolReference("double_symbol_1")), + new Variable("double_symbol_1", DOUBLE)); + } + @Test public void testTranslateLike() { @@ -234,6 +272,11 @@ private void assertTranslationRoundTrips(Session session, Expression expression, assertTranslationFromConnectorExpression(session, connectorExpression, expression); } + private void assertTranslationToConnectorExpression(Session session, Expression expression, ConnectorExpression connectorExpression) + { + assertTranslationToConnectorExpression(session, expression, Optional.of(connectorExpression)); + } + private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional connectorExpression) { Optional translation = translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, PLANNER_CONTEXT); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java index 13f5a77643a5..ce2fe443243d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPartialTranslator.java @@ -83,10 +83,7 @@ public void testPartialTranslator() assertFullTranslation(symbolReference1); assertFullTranslation(dereferenceExpression1); assertFullTranslation(stringLiteral); - - assertPartialTranslation( - new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1), - List.of(symbolReference1, dereferenceExpression1)); + assertFullTranslation(new ArithmeticBinaryExpression(ADD, symbolReference1, dereferenceExpression1)); assertPartialTranslation( new CoalesceExpression( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java index b7c2ec5ffc13..558fcc9c9f92 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateIntoTableScan.java @@ -39,10 +39,13 @@ import io.trino.spi.type.Type; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.tree.ArithmeticBinaryExpression; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.CoalesceExpression; import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.LogicalExpression; import io.trino.sql.tree.LongLiteral; +import io.trino.sql.tree.NullLiteral; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.StringLiteral; import io.trino.sql.tree.SymbolReference; @@ -56,8 +59,10 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.constrainedTableScanWithTableLayout; @@ -217,13 +222,16 @@ public void doesNotConsumeRemainingPredicateIfNewDomainIsWider() .functionCallBuilder(QualifiedName.of("rand")) .build(), new GenericLiteral("BIGINT", "42")), - new ComparisonExpression( - EQUAL, - new ArithmeticBinaryExpression( - MODULUS, - new SymbolReference("nationkey"), - new GenericLiteral("BIGINT", "17")), - new GenericLiteral("BIGINT", "44")), + // non-translatable to connector expression + new CoalesceExpression( + new Cast(new NullLiteral(), toSqlType(BOOLEAN)), + new ComparisonExpression( + EQUAL, + new ArithmeticBinaryExpression( + MODULUS, + new SymbolReference("nationkey"), + new GenericLiteral("BIGINT", "17")), + new GenericLiteral("BIGINT", "44"))), LogicalExpression.or( new ComparisonExpression( EQUAL, diff --git a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java index e8f0f4510fc5..a7ae39b861de 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java +++ b/core/trino-spi/src/main/java/io/trino/spi/expression/StandardFunctions.java @@ -35,5 +35,35 @@ private StandardFunctions() {} public static final FunctionName GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME = new FunctionName("$greater_than_or_equal"); public static final FunctionName IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME = new FunctionName("$is_distinct_from"); + /** + * Arithmetic addition. + */ + public static final FunctionName ADD_FUNCTION_NAME = new FunctionName("$add"); + + /** + * Arithmetic subtraction. + */ + public static final FunctionName SUBTRACT_FUNCTION_NAME = new FunctionName("$subtract"); + + /** + * Arithmetic multiplication. + */ + public static final FunctionName MULTIPLY_FUNCTION_NAME = new FunctionName("$multiply"); + + /** + * Arithmetic division. + */ + public static final FunctionName DIVIDE_FUNCTION_NAME = new FunctionName("$divide"); + + /** + * Arithmetic modulus. + */ + public static final FunctionName MODULUS_FUNCTION_NAME = new FunctionName("$modulus"); + + /** + * Arithmetic unary minus. + */ + public static final FunctionName NEGATE_FUNCTION_NAME = new FunctionName("$negate"); + public static final FunctionName LIKE_PATTERN_FUNCTION_NAME = new FunctionName("$like_pattern"); }