From 65ea74383b91e936df9d2839c17a7244e5e04481 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 8 Dec 2023 15:47:35 +0100 Subject: [PATCH 1/6] Fix code indentation --- .../jdbc/TestDefaultJdbcQueryBuilder.java | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java index 0c04b64d4ad7..c58a9b5633d2 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java @@ -194,41 +194,41 @@ public void testNormalBuildSql() { TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.builder() .put(columns.get(0), Domain.create(SortedRangeSet.copyOf(BIGINT, - ImmutableList.of( - Range.equal(BIGINT, 128L), - Range.equal(BIGINT, 180L), - Range.equal(BIGINT, 233L), - Range.lessThan(BIGINT, 25L), - Range.range(BIGINT, 66L, true, 96L, true), - Range.greaterThan(BIGINT, 192L))), + ImmutableList.of( + Range.equal(BIGINT, 128L), + Range.equal(BIGINT, 180L), + Range.equal(BIGINT, 233L), + Range.lessThan(BIGINT, 25L), + Range.range(BIGINT, 66L, true, 96L, true), + Range.greaterThan(BIGINT, 192L))), false)) .put(columns.get(1), Domain.create(SortedRangeSet.copyOf(DOUBLE, - ImmutableList.of( - Range.equal(DOUBLE, 200011.0), - Range.equal(DOUBLE, 200014.0), - Range.equal(DOUBLE, 200017.0), - Range.equal(DOUBLE, 200116.5), - Range.range(DOUBLE, 200030.0, true, 200036.0, true), - Range.range(DOUBLE, 200048.0, true, 200099.0, true))), + ImmutableList.of( + Range.equal(DOUBLE, 200011.0), + Range.equal(DOUBLE, 200014.0), + Range.equal(DOUBLE, 200017.0), + Range.equal(DOUBLE, 200116.5), + Range.range(DOUBLE, 200030.0, true, 200036.0, true), + Range.range(DOUBLE, 200048.0, true, 200099.0, true))), false)) .put(columns.get(7), Domain.create(SortedRangeSet.copyOf(TINYINT, - ImmutableList.of( - Range.range(TINYINT, 60L, true, 70L, false), - Range.range(TINYINT, 52L, true, 55L, false))), + ImmutableList.of( + Range.range(TINYINT, 60L, true, 70L, false), + Range.range(TINYINT, 52L, true, 55L, false))), false)) .put(columns.get(8), Domain.create(SortedRangeSet.copyOf(SMALLINT, - ImmutableList.of( - Range.range(SMALLINT, -75L, true, -68L, true), - Range.range(SMALLINT, -200L, true, -100L, false))), + ImmutableList.of( + Range.range(SMALLINT, -75L, true, -68L, true), + Range.range(SMALLINT, -200L, true, -100L, false))), false)) .put(columns.get(9), Domain.create(SortedRangeSet.copyOf(INTEGER, - ImmutableList.of( - Range.equal(INTEGER, 80L), - Range.equal(INTEGER, 96L), - Range.lessThan(INTEGER, 0L))), + ImmutableList.of( + Range.equal(INTEGER, 80L), + Range.equal(INTEGER, 96L), + Range.lessThan(INTEGER, 0L))), false)) .put(columns.get(2), Domain.create(SortedRangeSet.copyOf(BOOLEAN, - ImmutableList.of(Range.equal(BOOLEAN, true))), + ImmutableList.of(Range.equal(BOOLEAN, true))), false)) .buildOrThrow()); @@ -309,10 +309,10 @@ public void testBuildSqlWithFloat() { TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( columns.get(10), Domain.create(SortedRangeSet.copyOf(REAL, - ImmutableList.of( - Range.equal(REAL, (long) floatToRawIntBits(100.0f + 0)), - Range.equal(REAL, (long) floatToRawIntBits(100.008f + 0)), - Range.equal(REAL, (long) floatToRawIntBits(100.0f + 14)))), + ImmutableList.of( + Range.equal(REAL, (long) floatToRawIntBits(100.0f + 0)), + Range.equal(REAL, (long) floatToRawIntBits(100.008f + 0)), + Range.equal(REAL, (long) floatToRawIntBits(100.0f + 14)))), false))); Connection connection = database.getConnection(); @@ -343,10 +343,10 @@ public void testBuildSqlWithVarchar() { TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( columns.get(3), Domain.create(SortedRangeSet.copyOf(VARCHAR, - ImmutableList.of( - Range.range(VARCHAR, utf8Slice("test_str_700"), true, utf8Slice("test_str_702"), false), - Range.equal(VARCHAR, utf8Slice("test_str_180")), - Range.equal(VARCHAR, utf8Slice("test_str_196")))), + ImmutableList.of( + Range.range(VARCHAR, utf8Slice("test_str_700"), true, utf8Slice("test_str_702"), false), + Range.equal(VARCHAR, utf8Slice("test_str_180")), + Range.equal(VARCHAR, utf8Slice("test_str_196")))), false))); Connection connection = database.getConnection(); @@ -379,10 +379,10 @@ public void testBuildSqlWithChar() CharType charType = CharType.createCharType(0); TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( columns.get(11), Domain.create(SortedRangeSet.copyOf(charType, - ImmutableList.of( - Range.range(charType, utf8Slice("test_str_700"), true, utf8Slice("test_str_702"), false), - Range.equal(charType, utf8Slice("test_str_180")), - Range.equal(charType, utf8Slice("test_str_196")))), + ImmutableList.of( + Range.range(charType, utf8Slice("test_str_700"), true, utf8Slice("test_str_702"), false), + Range.equal(charType, utf8Slice("test_str_180")), + Range.equal(charType, utf8Slice("test_str_196")))), false))); Connection connection = database.getConnection(); @@ -419,16 +419,16 @@ public void testBuildSqlWithDateTime() { TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( columns.get(4), Domain.create(SortedRangeSet.copyOf(DATE, - ImmutableList.of( - Range.range(DATE, toDays(2016, 6, 7), true, toDays(2016, 6, 17), false), - Range.equal(DATE, toDays(2016, 6, 3)), - Range.equal(DATE, toDays(2016, 10, 21)))), + ImmutableList.of( + Range.range(DATE, toDays(2016, 6, 7), true, toDays(2016, 6, 17), false), + Range.equal(DATE, toDays(2016, 6, 3)), + Range.equal(DATE, toDays(2016, 10, 21)))), false), columns.get(5), Domain.create(SortedRangeSet.copyOf(TIME_MILLIS, - ImmutableList.of( - Range.range(TIME_MILLIS, toTimeRepresentation(6, 12, 23), false, toTimeRepresentation(8, 23, 37), true), - Range.equal(TIME_MILLIS, toTimeRepresentation(2, 3, 4)), - Range.equal(TIME_MILLIS, toTimeRepresentation(20, 23, 37)))), + ImmutableList.of( + Range.range(TIME_MILLIS, toTimeRepresentation(6, 12, 23), false, toTimeRepresentation(8, 23, 37), true), + Range.equal(TIME_MILLIS, toTimeRepresentation(2, 3, 4)), + Range.equal(TIME_MILLIS, toTimeRepresentation(20, 23, 37)))), false))); Connection connection = database.getConnection(); @@ -472,10 +472,10 @@ public void testBuildSqlWithTimestamp() { TupleDomain tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of( columns.get(6), Domain.create(SortedRangeSet.copyOf(TIMESTAMP_MILLIS, - ImmutableList.of( - Range.equal(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 6, 3, 0, 23, 37)), - Range.equal(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 10, 19, 16, 23, 37)), - Range.range(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 6, 7, 8, 23, 37), false, toTrinoTimestamp(2016, 6, 9, 12, 23, 37), true))), + ImmutableList.of( + Range.equal(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 6, 3, 0, 23, 37)), + Range.equal(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 10, 19, 16, 23, 37)), + Range.range(TIMESTAMP_MILLIS, toTrinoTimestamp(2016, 6, 7, 8, 23, 37), false, toTrinoTimestamp(2016, 6, 9, 12, 23, 37), true))), false))); Connection connection = database.getConnection(); From 9e02e7d8f55c33724787dcc34614d9959a558a35 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 19 Dec 2023 14:50:28 +0100 Subject: [PATCH 2/6] Add some traceability which option is failing --- .../java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java | 7 +++++++ .../plugin/postgresql/TestPostgreSqlConnectorTest.java | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 97c158051b2b..40eb2f8da22e 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -14,6 +14,7 @@ package io.trino.plugin.jdbc; import com.google.common.collect.ImmutableList; +import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.Session; import io.trino.spi.QueryId; @@ -124,6 +125,8 @@ public abstract class BaseJdbcConnectorTest extends BaseConnectorTest { + private static final Logger log = Logger.get(BaseJdbcConnectorTest.class); + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed(getClass().getName())); protected abstract SqlExecutor onRemoteDatabase(); @@ -1196,6 +1199,8 @@ public void testJoinPushdown() "nation_lowercase", "AS SELECT nationkey, lower(name) name, regionkey FROM nation")) { for (JoinOperator joinOperator : JoinOperator.values()) { + log.info("Testing joinOperator=%s", joinOperator); + if (joinOperator == FULL_JOIN && !hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN)) { assertThat(query(session, "SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.regionkey = r.regionkey")) .joinIsNotFullyPushedDown(); @@ -1255,6 +1260,7 @@ public void testJoinPushdown() // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join for (String operator : nonEqualities) { + log.info("Testing [joinOperator=%s] operator=%s on number", joinOperator, operator); assertJoinConditionallyPushedDown( session, format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", joinOperator, operator), @@ -1263,6 +1269,7 @@ public void testJoinPushdown() // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join for (String operator : nonEqualities) { + log.info("Testing [joinOperator=%s] operator=%s on varchar", joinOperator, operator); assertJoinConditionallyPushedDown( session, format("SELECT n.name, nl.name FROM nation n %s %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator), diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 8170cdcff5f2..77d01fc9de47 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.log.Logger; import io.airlift.units.Duration; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; @@ -80,6 +81,8 @@ public class TestPostgreSqlConnectorTest extends BaseJdbcConnectorTest { + private static final Logger log = Logger.get(TestPostgreSqlConnectorTest.class); + protected TestingPostgreSqlServer postgreSqlServer; @Override @@ -641,6 +644,8 @@ public void testStringJoinPushdownWithCollate() // inequality for (String operator : nonEqualities) { + log.info("Testing operator=%s", operator); + // bigint inequality predicate assertThat(query(withoutDynamicFiltering, format("SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey %s r.regionkey", operator))) // Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes @@ -654,6 +659,7 @@ public void testStringJoinPushdownWithCollate() // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join for (String operator : nonEqualities) { + log.info("Testing operator=%s", operator); assertConditionallyPushedDown( session, format("SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", operator), @@ -663,6 +669,7 @@ public void testStringJoinPushdownWithCollate() // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join for (String operator : nonEqualities) { + log.info("Testing operator=%s", operator); assertConditionallyPushedDown( session, format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", nationLowercaseTable.getName(), operator), From ed8b358491c6f8a842c6e5f07124985ba3a7518d Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 20 Dec 2023 10:50:03 +0100 Subject: [PATCH 3/6] Move extractVariables utility --- .../io/trino/metadata/MetadataManager.java | 4 +- .../sql/planner/ConnectorExpressions.java | 46 ------------------- .../base/expression/ConnectorExpressions.java | 22 +++++++++ 3 files changed, 24 insertions(+), 48 deletions(-) delete mode 100644 core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressions.java diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 9c9221e38ebf..67e5666663a3 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -112,7 +112,6 @@ import io.trino.spi.type.TypeOperators; import io.trino.sql.analyzer.TypeSignatureProvider; import io.trino.sql.parser.SqlParser; -import io.trino.sql.planner.ConnectorExpressions; import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.tree.QualifiedName; import io.trino.transaction.TransactionManager; @@ -158,6 +157,7 @@ import static io.trino.metadata.RedirectionAwareTableHandle.noRedirection; import static io.trino.metadata.RedirectionAwareTableHandle.withRedirectionTo; import static io.trino.metadata.SignatureBinder.applyBoundVariables; +import static io.trino.plugin.base.expression.ConnectorExpressions.extractVariables; import static io.trino.spi.ErrorType.EXTERNAL; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; @@ -2040,7 +2040,7 @@ private void verifyProjection(TableHandle table, List proje .map(Assignment::getVariable) .collect(toImmutableSet()); projections.stream() - .flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream()) + .flatMap(connectorExpression -> extractVariables(connectorExpression).stream()) .map(Variable::getName) .filter(variableName -> !assignedVariables.contains(variableName)) .findAny() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressions.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressions.java deleted file mode 100644 index c7ecbec76528..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressions.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.sql.planner; - -import com.google.common.graph.SuccessorsFunction; -import com.google.common.graph.Traverser; -import io.trino.spi.expression.ConnectorExpression; -import io.trino.spi.expression.Variable; - -import java.util.List; -import java.util.stream.Stream; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Streams.stream; -import static java.util.Objects.requireNonNull; - -public final class ConnectorExpressions -{ - private ConnectorExpressions() {} - - public static List extractVariables(ConnectorExpression expression) - { - return preOrder(expression) - .filter(Variable.class::isInstance) - .map(Variable.class::cast) - .collect(toImmutableList()); - } - - public static Stream preOrder(ConnectorExpression expression) - { - return stream( - Traverser.forTree((SuccessorsFunction) ConnectorExpression::getChildren) - .depthFirstPreOrder(requireNonNull(expression, "expression is null"))); - } -} diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java index 01fc9a38cf61..3971a2995cc2 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java @@ -14,21 +14,36 @@ package io.trino.plugin.base.expression; import com.google.common.collect.ImmutableList; +import com.google.common.graph.SuccessorsFunction; +import com.google.common.graph.Traverser; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Variable; import java.util.Arrays; import java.util.List; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.Streams.stream; import static io.trino.spi.expression.Constant.TRUE; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static java.util.Objects.requireNonNull; public final class ConnectorExpressions { private ConnectorExpressions() {} + public static List extractVariables(ConnectorExpression expression) + { + return preOrder(expression) + .filter(Variable.class::isInstance) + .map(Variable.class::cast) + .collect(toImmutableList()); + } + public static List extractConjuncts(ConnectorExpression expression) { ImmutableList.Builder resultBuilder = ImmutableList.builder(); @@ -64,4 +79,11 @@ public static ConnectorExpression and(List expressions) } return getOnlyElement(expressions); } + + private static Stream preOrder(ConnectorExpression expression) + { + return stream( + Traverser.forTree((SuccessorsFunction) ConnectorExpression::getChildren) + .depthFirstPreOrder(requireNonNull(expression, "expression is null"))); + } } From 4f1cc7f15842cb10f052e731a301c416581bba4a Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 22 Dec 2023 00:02:37 +0100 Subject: [PATCH 4/6] Pull RewriteComparison.ComparisonOperator as top level class --- .../jdbc/expression/ComparisonOperator.java | 70 +++++++++++++++++++ .../jdbc/expression/RewriteComparison.java | 52 -------------- .../expression/TestRewriteComparison.java | 2 +- .../io/trino/plugin/ignite/IgniteClient.java | 3 +- .../trino/plugin/phoenix5/PhoenixClient.java | 3 +- .../plugin/postgresql/PostgreSqlClient.java | 3 +- .../plugin/sqlserver/SqlServerClient.java | 3 +- 7 files changed, 79 insertions(+), 57 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ComparisonOperator.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ComparisonOperator.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ComparisonOperator.java new file mode 100644 index 000000000000..f9fa74844289 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/ComparisonOperator.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import io.trino.spi.expression.FunctionName; + +import java.util.Map; +import java.util.stream.Stream; + +import static com.google.common.base.Verify.verifyNotNull; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public enum ComparisonOperator +{ + EQUAL(EQUAL_OPERATOR_FUNCTION_NAME, "="), + NOT_EQUAL(NOT_EQUAL_OPERATOR_FUNCTION_NAME, "<>"), + LESS_THAN(LESS_THAN_OPERATOR_FUNCTION_NAME, "<"), + LESS_THAN_OR_EQUAL(LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, "<="), + GREATER_THAN(GREATER_THAN_OPERATOR_FUNCTION_NAME, ">"), + GREATER_THAN_OR_EQUAL(GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ">="), + IS_DISTINCT_FROM(IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME, "IS DISTINCT FROM"), + /**/; + + private final FunctionName functionName; + private final String operator; + + private static final Map OPERATOR_BY_FUNCTION_NAME = Stream.of(values()) + .collect(toImmutableMap(ComparisonOperator::getFunctionName, identity())); + + ComparisonOperator(FunctionName functionName, String operator) + { + this.functionName = requireNonNull(functionName, "functionName is null"); + this.operator = requireNonNull(operator, "operator is null"); + } + + public FunctionName getFunctionName() + { + return functionName; + } + + public String getOperator() + { + return operator; + } + + public static ComparisonOperator forFunctionName(FunctionName functionName) + { + return verifyNotNull(OPERATOR_BY_FUNCTION_NAME.get(functionName), "Function name not recognized: %s", functionName); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java index 10db5e84ab6c..2231ca86e33b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteComparison.java @@ -23,14 +23,10 @@ import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.FunctionName; -import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Stream; import static com.google.common.base.Verify.verify; -import static com.google.common.base.Verify.verifyNotNull; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.matching.Capture.newCapture; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; @@ -39,17 +35,8 @@ import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; -import static io.trino.spi.expression.StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; -import static io.trino.spi.expression.StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; import static java.lang.String.format; -import static java.util.Objects.requireNonNull; -import static java.util.function.Function.identity; public class RewriteComparison implements ConnectorExpressionRule @@ -57,45 +44,6 @@ public class RewriteComparison private static final Capture LEFT = newCapture(); private static final Capture RIGHT = newCapture(); - public enum ComparisonOperator - { - EQUAL(EQUAL_OPERATOR_FUNCTION_NAME, "="), - NOT_EQUAL(NOT_EQUAL_OPERATOR_FUNCTION_NAME, "<>"), - LESS_THAN(LESS_THAN_OPERATOR_FUNCTION_NAME, "<"), - LESS_THAN_OR_EQUAL(LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, "<="), - GREATER_THAN(GREATER_THAN_OPERATOR_FUNCTION_NAME, ">"), - GREATER_THAN_OR_EQUAL(GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ">="), - IS_DISTINCT_FROM(IS_DISTINCT_FROM_OPERATOR_FUNCTION_NAME, "IS DISTINCT FROM"), - /**/; - - private final FunctionName functionName; - private final String operator; - - private static final Map OPERATOR_BY_FUNCTION_NAME = Stream.of(values()) - .collect(toImmutableMap(ComparisonOperator::getFunctionName, identity())); - - ComparisonOperator(FunctionName functionName, String operator) - { - this.functionName = requireNonNull(functionName, "functionName is null"); - this.operator = requireNonNull(operator, "operator is null"); - } - - private FunctionName getFunctionName() - { - return functionName; - } - - private String getOperator() - { - return operator; - } - - private static ComparisonOperator forFunctionName(FunctionName functionName) - { - return verifyNotNull(OPERATOR_BY_FUNCTION_NAME.get(functionName), "Function name not recognized: %s", functionName); - } - } - private final Pattern pattern; public RewriteComparison(Set enabledOperators) diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestRewriteComparison.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestRewriteComparison.java index 80e70a2e6007..57504cb16b75 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestRewriteComparison.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/expression/TestRewriteComparison.java @@ -26,7 +26,7 @@ public class TestRewriteComparison public void testOperatorEnumsInSync() { assertThat( - Stream.of(RewriteComparison.ComparisonOperator.values()) + Stream.of(ComparisonOperator.values()) .map(Enum::name)) .containsExactlyInAnyOrder( Stream.of(ComparisonExpression.Operator.values()) diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index 9309cd5f259c..2b0c5317fb0a 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -45,6 +45,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; @@ -163,7 +164,7 @@ public IgniteClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() - .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .addStandardRules(this::quoted) .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") .map("$not($is_null(value))").to("value IS NOT NULL") diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java index 9a381e6b1043..e087630842c7 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java @@ -39,6 +39,7 @@ import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; @@ -246,7 +247,7 @@ public PhoenixClient(PhoenixConfig config, ConnectionFactory connectionFactory, getConnectionProperties(config).forEach((k, v) -> configuration.set((String) k, (String) v)); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index affd261a925b..3b4a43827c83 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -66,6 +66,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; +import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; @@ -306,7 +307,7 @@ public PostgreSqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) // TODO allow all comparison operators for numeric types - .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .map("$add(left: integer_type, right: integer_type)").to("left + right") diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index c1cdd1077027..fcb974b93988 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -60,6 +60,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; import io.trino.plugin.jdbc.aggregation.ImplementMinMax; import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.expression.RewriteComparison; @@ -300,7 +301,7 @@ public SqlServerClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - .add(new RewriteComparison(ImmutableSet.of(RewriteComparison.ComparisonOperator.EQUAL, RewriteComparison.ComparisonOperator.NOT_EQUAL))) + .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .map("$add(left: integer_type, right: integer_type)").to("left + right") From 1db64a7ac81ae6b217e7312db443574220b80147 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 1 Dec 2023 15:12:14 +0100 Subject: [PATCH 5/6] Implement complex join pushdown in JDBC connectors Implement non-deprecated `ConnectorMetadata.applyJoin` overload in `DefaultJdbcMetadata`. Thew old implementation is retained as a safety valve. The new implementation is not limited to the `List` model, so allows pushdown of joins involving more complex expressions, such as arithmetics. The `BaseJdbcClient.implementJoin` and `QueryBuilder.prepareJoinQuery` methods logically changed, but the old implementation is left as the fallback. These methods were extension points, so the old implementations are renamed to ensure implementors are updated. For example, if an implementation was overriding `BaseJdbcClient.implementJoin` it most likely wants to override the new `implementJoin` method as well, and this is reminded about by rename of the old method. --- .../io/trino/sql/query/QueryAssertions.java | 4 + .../base/expression/ConnectorExpressions.java | 4 + .../io/trino/plugin/jdbc/BaseJdbcClient.java | 18 ++-- .../trino/plugin/jdbc/CachingJdbcClient.java | 8 +- .../plugin/jdbc/DefaultJdbcMetadata.java | 30 +++--- .../plugin/jdbc/DefaultQueryBuilder.java | 36 +++---- .../plugin/jdbc/ForwardingJdbcClient.java | 8 +- .../java/io/trino/plugin/jdbc/JdbcClient.java | 6 +- .../io/trino/plugin/jdbc/QueryBuilder.java | 6 +- .../RewriteCaseSensitiveComparison.java | 94 +++++++++++++++++++ .../jdbc/jmx/StatisticsAwareJdbcClient.java | 9 +- .../plugin/jdbc/BaseJdbcConnectorTest.java | 20 +++- .../jdbc/TestDefaultJdbcQueryBuilder.java | 19 ++-- .../io/trino/plugin/ignite/IgniteClient.java | 15 ++- .../trino/plugin/mariadb/MariaDbClient.java | 27 +++++- .../io/trino/plugin/mysql/MySqlClient.java | 17 +++- plugin/trino-oracle/pom.xml | 5 + .../io/trino/plugin/oracle/OracleClient.java | 15 +++ .../oracle/RewriteStringComparison.java | 93 ++++++++++++++++++ .../trino/plugin/phoenix5/PhoenixClient.java | 17 ++++ .../plugin/postgresql/PostgreSqlClient.java | 35 +++++-- .../postgresql/TestPostgreSqlClient.java | 7 +- .../trino/plugin/redshift/RedshiftClient.java | 8 +- .../plugin/singlestore/SingleStoreClient.java | 32 ++++++- .../plugin/sqlserver/SqlServerClient.java | 19 ++-- 25 files changed, 441 insertions(+), 111 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCaseSensitiveComparison.java create mode 100644 plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteStringComparison.java diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index 75d557980943..72705779f319 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -571,7 +571,10 @@ public QueryAssert isNotFullyPushedDown(PlanMatchPattern retainedSubplan) /** * Verifies join query is not fully pushed down by containing JOIN node. + * + * @deprecated because the method is not tested in BaseQueryAssertionsTest yet */ + @Deprecated @CanIgnoreReturnValue public QueryAssert joinIsNotFullyPushedDown() { @@ -580,6 +583,7 @@ public QueryAssert joinIsNotFullyPushedDown() .whereIsInstanceOfAny(JoinNode.class) .findFirst() .isEmpty()) { + // TODO show then plan when assertions fails (like hasPlan()) and add negative test coverage in BaseQueryAssertionsTest throw new IllegalStateException("Join node should be present in explain plan, when pushdown is not applied"); } }); diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java index 3971a2995cc2..71d587247f6f 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/expression/ConnectorExpressions.java @@ -53,6 +53,10 @@ public static List extractConjuncts(ConnectorExpression exp private static void extractConjuncts(ConnectorExpression expression, ImmutableList.Builder resultBuilder) { + if (expression.equals(TRUE)) { + // Skip useless conjuncts. + return; + } if (expression instanceof Call call) { if (AND_FUNCTION_NAME.equals(call.getFunctionName())) { for (ConnectorExpression argument : call.getArguments()) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index a8150d624563..989317dae45d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -527,18 +527,12 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { - for (JdbcJoinCondition joinCondition : joinConditions) { - if (!isSupportedJoinCondition(session, joinCondition)) { - return Optional.empty(); - } - } - try (Connection connection = this.connectionFactory.openConnection(session)) { return Optional.of(queryBuilder.prepareJoinQuery( this, @@ -546,10 +540,10 @@ public Optional implementJoin( connection, joinType, leftSource, + leftProjections, rightSource, - joinConditions, - leftAssignments, - rightAssignments)); + rightProjections, + joinConditions)); } catch (SQLException e) { throw new TrinoException(JDBC_ERROR, e); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 032d9895ca1b..f473c94d96b9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -284,13 +284,13 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { - return delegate.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + return delegate.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index cee54f86096c..607d813b49ef 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -40,7 +40,6 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -73,6 +72,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; import java.util.OptionalLong; @@ -442,7 +442,7 @@ public Optional> applyJoin( JoinType joinType, ConnectorTableHandle left, ConnectorTableHandle right, - List joinConditions, + ConnectorExpression joinCondition, Map leftAssignments, Map rightAssignments, JoinStatistics statistics) @@ -478,26 +478,32 @@ public Optional> applyJoin( } Map newRightColumns = newRightColumnsBuilder.buildOrThrow(); - ImmutableList.Builder jdbcJoinConditions = ImmutableList.builder(); - for (JoinCondition joinCondition : joinConditions) { - Optional leftColumn = getVariableColumnHandle(leftAssignments, joinCondition.getLeftExpression()); - Optional rightColumn = getVariableColumnHandle(rightAssignments, joinCondition.getRightExpression()); - if (leftColumn.isEmpty() || rightColumn.isEmpty()) { + Map assignments = ImmutableMap.builder() + .putAll(leftAssignments.entrySet().stream() + .collect(toImmutableMap(Entry::getKey, entry -> newLeftColumns.get((JdbcColumnHandle) entry.getValue())))) + .putAll(rightAssignments.entrySet().stream() + .collect(toImmutableMap(Entry::getKey, entry -> newRightColumns.get((JdbcColumnHandle) entry.getValue())))) + .buildOrThrow(); + + ImmutableList.Builder joinConditions = ImmutableList.builder(); + for (ConnectorExpression conjunct : extractConjuncts(joinCondition)) { + Optional converted = jdbcClient.convertPredicate(session, conjunct, assignments); + if (converted.isEmpty()) { return Optional.empty(); } - jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get())); + joinConditions.add(converted.get()); } Optional joinQuery = jdbcClient.implementJoin( session, joinType, asPreparedQuery(leftHandle), + newLeftColumns.entrySet().stream() + .collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())), asPreparedQuery(rightHandle), - jdbcJoinConditions.build(), newRightColumns.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())), - newLeftColumns.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())), + .collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())), + joinConditions.build(), statistics); if (joinQuery.isEmpty()) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index dfd3b0674329..7d3a4f5f0524 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -117,35 +117,32 @@ public PreparedQuery prepareJoinQuery( Connection connection, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map leftAssignments, - Map rightAssignments) + Map rightProjections, + List joinConditions) { - // Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here. - verify(!leftAssignments.isEmpty(), "leftAssignments is empty"); - verify(!rightAssignments.isEmpty(), "rightAssignments is empty"); // Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here verify(!joinConditions.isEmpty(), "joinConditions is empty"); - String leftRelationAlias = "l"; - String rightRelationAlias = "r"; - String query = format( - "SELECT %s, %s FROM (%s) %s %s (%s) %s ON %s", - formatAssignments(client, leftRelationAlias, leftAssignments), - formatAssignments(client, rightRelationAlias, rightAssignments), + // The subquery aliases (`l` and `r`) are needed by some databases, but are not needed for expressions + // The joinConditions and output columns are aliased to use unique names. + "SELECT * FROM (SELECT %s FROM (%s) l) l %s (SELECT %s FROM (%s) r) r ON %s", + formatProjections(client, leftProjections), leftSource.getQuery(), - leftRelationAlias, formatJoinType(joinType), + formatProjections(client, rightProjections), rightSource.getQuery(), - rightRelationAlias, joinConditions.stream() - .map(condition -> formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition)) - .collect(joining(" AND "))); + .map(ParameterizedExpression::expression) + .collect(joining(") AND (", "(", ")"))); List parameters = ImmutableList.builder() .addAll(leftSource.getParameters()) .addAll(rightSource.getParameters()) + .addAll(joinConditions.stream() + .flatMap(expression -> expression.parameters().stream()) + .iterator()) .build(); return new PreparedQuery(query, parameters); } @@ -296,6 +293,13 @@ protected String buildJoinColumn(JdbcClient client, JdbcColumnHandle columnHandl return client.quoted(columnHandle.getColumnName()); } + protected String formatProjections(JdbcClient client, Map projections) + { + return projections.entrySet().stream() + .map(entry -> format("%s AS %s", client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue()))) + .collect(joining(", ")); + } + protected String formatAssignments(JdbcClient client, String relationAlias, Map assignments) { return assignments.entrySet().stream() diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 2f784634a44b..f2ac9981f9ef 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -211,13 +211,13 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { - return delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + return delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } @Override diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index 38e6f7e80114..a52f100c78b1 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -127,10 +127,10 @@ Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics); boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index 52cb58e80cb5..a6dc6f53d985 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -47,10 +47,10 @@ PreparedQuery prepareJoinQuery( Connection connection, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map leftAssignments, - Map rightAssignments); + Map rightProjections, + List joinConditions); PreparedQuery prepareDeleteQuery( JdbcClient client, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCaseSensitiveComparison.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCaseSensitiveComparison.java new file mode 100644 index 000000000000..2ff0c1c8ad70 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCaseSensitiveComparison.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.VarcharType; + +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.plugin.jdbc.CaseSensitivity.CASE_SENSITIVE; +import static io.trino.spi.type.BooleanType.BOOLEAN; + +public class RewriteCaseSensitiveComparison + implements ConnectorExpressionRule +{ + private static final Capture LEFT = newCapture(); + private static final Capture RIGHT = newCapture(); + + private final Pattern pattern; + + public RewriteCaseSensitiveComparison(Set enabledOperators) + { + Set functionNames = enabledOperators.stream() + .map(ComparisonOperator::getFunctionName) + .collect(toImmutableSet()); + + pattern = call() + .with(type().equalTo(BOOLEAN)) + .with(functionName().matching(functionNames::contains)) + .with(argumentCount().equalTo(2)) + .with(argument(0).matching(variable().with(type().matching(VarcharType.class::isInstance)).capturedAs(LEFT))) + .with(argument(1).matching(variable().with(type().matching(VarcharType.class::isInstance)).capturedAs(RIGHT))); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public Optional rewrite(Call expression, Captures captures, RewriteContext context) + { + ComparisonOperator comparison = ComparisonOperator.forFunctionName(expression.getFunctionName()); + Variable firstArgument = captures.get(LEFT); + Variable secondArgument = captures.get(RIGHT); + + if (!isCaseSensitive(firstArgument, context) || !isCaseSensitive(secondArgument, context)) { + return Optional.empty(); + } + return context.defaultRewrite(firstArgument).flatMap(first -> + context.defaultRewrite(secondArgument).map(second -> + new ParameterizedExpression( + "(%s) %s (%s)".formatted(first.expression(), comparison.getOperator(), second.expression()), + ImmutableList.builder() + .addAll(first.parameters()) + .addAll(second.parameters()) + .build()))); + } + + private static boolean isCaseSensitive(Variable variable, RewriteContext context) + { + return ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle().getCaseSensitivity().equals(Optional.of(CASE_SENSITIVE)); + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index d07bcfa5178c..a19087465750 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -17,7 +17,6 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; -import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcProcedureHandle; import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; @@ -231,13 +230,13 @@ public CallableStatement buildProcedure(ConnectorSession session, Connection con public Optional implementJoin(ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { - return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } @Override diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 40eb2f8da22e..1cf08c54b3fd 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1325,6 +1325,20 @@ public void testJoinPushdown() } } + @Test + public void testComplexJoinPushdown() + { + String catalog = getSession().getCatalog().orElseThrow(); + Session session = joinPushdownEnabled(getSession()); + String query = "SELECT n.name, o.orderstatus FROM nation n JOIN orders o ON n.regionkey = o.orderkey AND n.nationkey + o.custkey - 3 = 0"; + + // The join can be pushed down + assertJoinConditionallyPushedDown( + session, + query, + hasBehavior(SUPPORTS_JOIN_PUSHDOWN) && hasBehavior(SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN)); + } + @Test public void testExplainAnalyzePhysicalReadWallTime() { @@ -1388,8 +1402,7 @@ protected void assertConditionallyOrderedPushedDown( protected boolean expectJoinPushdown(String operator) { if ("IS NOT DISTINCT FROM".equals(operator)) { - // TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM - return false; + return hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM); } return switch (toJoinConditionOperator(operator)) { case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> true; @@ -1406,8 +1419,7 @@ protected boolean expectJoinPushdownOnInequalityOperator(JoinOperator joinOperat private boolean expectVarcharJoinPushdown(String operator) { if ("IS NOT DISTINCT FROM".equals(operator)) { - // TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM - return false; + return hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM) && hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY); } return switch (toJoinConditionOperator(operator)) { case EQUAL, NOT_EQUAL -> hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY); diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java index c58a9b5633d2..a407d624d5d8 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java @@ -22,7 +22,6 @@ import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.Domain; @@ -517,17 +516,17 @@ public void testBuildJoinSql() connection, JoinType.INNER, new PreparedQuery("SELECT * FROM \"test_table\"", List.of()), + ImmutableMap.of(columns.get(2), "name1", columns.get(7), "lcol7"), new PreparedQuery("SELECT * FROM \"test_table\"", List.of()), - List.of(new JdbcJoinCondition(columns.get(7), JoinCondition.Operator.EQUAL, columns.get(8))), - Map.of(columns.get(2), "name1"), - Map.of(columns.get(3), "name2")); + ImmutableMap.of(columns.get(3), "name2", columns.get(8), "rcol8"), + List.of(new ParameterizedExpression("\"lcol7\" = \"rcol8\"", List.of()))); try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(jdbcClient, SESSION, connection, preparedQuery, Optional.empty())) { - assertThat(preparedQuery.getQuery()).isEqualTo("" + - "SELECT l.\"col_2\" AS \"name1\", r.\"col_3\" AS \"name2\" FROM " + - "(SELECT * FROM \"test_table\") l " + - "INNER JOIN " + - "(SELECT * FROM \"test_table\") r " + - "ON l.\"col_7\" = r.\"col_8\""); + assertThat(preparedQuery.getQuery()).isEqualTo(""" + SELECT * FROM \ + (SELECT "col_2" AS "name1", "col_7" AS "lcol7" FROM (SELECT * FROM "test_table") l) l \ + INNER JOIN \ + (SELECT "col_3" AS "name2", "col_8" AS "rcol8" FROM (SELECT * FROM "test_table") r) r \ + ON ("lcol7" = "rcol8")"""); long count = 0; try (ResultSet resultSet = preparedStatement.executeQuery()) { while (resultSet.next()) { diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index 2b0c5317fb0a..7af9b61915fc 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -166,6 +166,13 @@ public IgniteClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .addStandardRules(this::quoted) + .map("$equal(left, right)").to("left = right") + .map("$not_equal(left, right)").to("left <> right") + .map("$is_distinct_from(left, right)").to("left IS DISTINCT FROM right") + .map("$less_than(left, right)").to("left < right") + .map("$less_than_or_equal(left, right)").to("left <= right") + .map("$greater_than(left, right)").to("left > right") + .map("$greater_than_or_equal(left, right)").to("left >= right") .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") @@ -574,10 +581,10 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { // Ignite does not support FULL JOIN @@ -585,7 +592,7 @@ public Optional implementJoin( return Optional.empty(); } - return super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } @Override diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java index af3a6c0ecca8..2f8d2ff527de 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java @@ -59,6 +59,7 @@ import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.statistics.ColumnStatistics; import io.trino.spi.statistics.Estimate; import io.trino.spi.statistics.TableStatistics; @@ -173,6 +174,7 @@ public class MariaDbClient private static final int PARSE_ERROR = 1064; private final boolean statisticsEnabled; + private final ConnectorExpressionRewriter connectorExpressionRewriter; private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject @@ -187,8 +189,17 @@ public MariaDbClient( super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) + // No "real" on the list; pushdown on REAL is disabled also in toColumnMapping + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double")) + .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") + .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") + // .map("$is_distinct_from(left: numeric_type, right: numeric_type)").to("left IS DISTINCT FROM right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") .build(); this.statisticsEnabled = statisticsConfig.isEnabled(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( @@ -222,6 +233,12 @@ public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHa return preventTextualTypeAggregationPushdown(groupingSets); } + @Override + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + { + return connectorExpressionRewriter.rewrite(session, expression, assignments); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); @@ -622,17 +639,17 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { if (joinType == JoinType.FULL_OUTER) { // Not supported in MariaDB return Optional.empty(); } - return super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } @Override diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index a53870ed76ce..cece995ce02b 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -270,6 +270,15 @@ public MySqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) + // No "real" on the list; pushdown on REAL is disabled also in toColumnMapping + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double")) + .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") + .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") + // .map("$is_distinct_from(left: numeric_type, right: numeric_type)").to("left IS DISTINCT FROM right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") .add(new RewriteLikeWithCaseSensitivity()) .add(new RewriteLikeEscapeWithCaseSensitivity()) .build(); @@ -1005,10 +1014,10 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { if (joinType == JoinType.FULL_OUTER) { @@ -1021,7 +1030,7 @@ public Optional implementJoin( leftSource, rightSource, statistics, - () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } @Override diff --git a/plugin/trino-oracle/pom.xml b/plugin/trino-oracle/pom.xml index dc28a30560dd..614570a10012 100644 --- a/plugin/trino-oracle/pom.xml +++ b/plugin/trino-oracle/pom.xml @@ -58,6 +58,11 @@ trino-base-jdbc + + io.trino + trino-matching + + io.trino trino-plugin-toolkit diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index 47837cff4e3e..e62131771408 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -62,6 +62,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.JoinCondition; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -231,6 +232,14 @@ public OracleClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) + .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") + .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .add(new RewriteStringComparison()) .build(); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(TRINO_BIGINT_TYPE, Optional.of("NUMBER"), Optional.of(0), Optional.of(0), Optional.empty(), Optional.empty()); @@ -538,6 +547,12 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + { + return connectorExpressionRewriter.rewrite(session, expression, assignments); + } + private static Optional toTypeHandle(DecimalType decimalType) { return Optional.of(new JdbcTypeHandle(OracleTypes.NUMBER, Optional.of("NUMBER"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty())); diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteStringComparison.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteStringComparison.java new file mode 100644 index 000000000000..1e3f8c73bf7d --- /dev/null +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/RewriteStringComparison.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.oracle; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ComparisonOperator; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.CharType; +import io.trino.spi.type.VarcharType; +import oracle.jdbc.OracleTypes; + +import java.util.Optional; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.spi.type.BooleanType.BOOLEAN; + +public class RewriteStringComparison + implements ConnectorExpressionRule +{ + private static final Capture FIRST_ARGUMENT = newCapture(); + private static final Capture SECOND_ARGUMENT = newCapture(); + private static final Pattern PATTERN = call() + .with(type().equalTo(BOOLEAN)) + .with(functionName().matching(Stream.of(ComparisonOperator.values()) + .filter(comparison -> comparison != ComparisonOperator.IS_DISTINCT_FROM) + .map(ComparisonOperator::getFunctionName) + .collect(toImmutableSet()) + ::contains)) + .with(argumentCount().equalTo(2)) + .with(argument(0).matching(variable().with(type().matching(type -> type instanceof CharType || type instanceof VarcharType)).capturedAs(FIRST_ARGUMENT))) + .with(argument(1).matching(variable().with(type().matching(type -> type instanceof CharType || type instanceof VarcharType)).capturedAs(SECOND_ARGUMENT))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(Call expression, Captures captures, RewriteContext context) + { + ComparisonOperator comparison = ComparisonOperator.forFunctionName(expression.getFunctionName()); + Variable firstArgument = captures.get(FIRST_ARGUMENT); + Variable secondArgument = captures.get(SECOND_ARGUMENT); + + if (isClob(firstArgument, context) || isClob(secondArgument, context)) { + return Optional.empty(); + } + return context.defaultRewrite(firstArgument).flatMap(first -> + context.defaultRewrite(secondArgument).map(second -> + new ParameterizedExpression( + "(%s) %s (%s)".formatted(first.expression(), comparison.getOperator(), second.expression()), + ImmutableList.builder() + .addAll(first.parameters()) + .addAll(second.parameters()) + .build()))); + } + + private static boolean isClob(Variable variable, RewriteContext context) + { + return switch (((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle().getJdbcType()) { + case OracleTypes.CLOB, OracleTypes.NCLOB -> true; + default -> false; + }; + } +} diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java index e087630842c7..7bd91251c851 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java @@ -51,6 +51,8 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaNotFoundException; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.expression.ConnectorExpression; @@ -264,6 +266,21 @@ public Optional convertPredicate(ConnectorSession sessi return connectorExpressionRewriter.rewrite(session, expression, assignments); } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + Map leftProjections, + PreparedQuery rightSource, + Map rightProjections, + List joinConditions, + JoinStatistics statistics) + { + // Joins are currently not supported + return Optional.empty(); + } + public Connection getConnection(ConnectorSession session) throws SQLException { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 3b4a43827c83..976f09c08285 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -66,10 +66,8 @@ import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; -import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; -import io.trino.plugin.jdbc.expression.RewriteComparison; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; @@ -277,6 +275,7 @@ public class PostgreSqlClient private final List tableTypes; private final boolean statisticsEnabled; private final ConnectorExpressionRewriter connectorExpressionRewriter; + private final ConnectorExpressionRewriter connectorExpressionRewriterWithCollate; private final AggregateFunctionRewriter aggregateFunctionRewriter; @Inject @@ -304,12 +303,18 @@ public PostgreSqlClient( this.statisticsEnabled = statisticsConfig.isEnabled(); - this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + JdbcConnectorExpressionRewriterBuilder connectorExpressionRewriterBuilder = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - // TODO allow all comparison operators for numeric types - .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) + .map("$equal(left, right)").to("left = right") + .map("$not_equal(left, right)").to("left <> right") + .map("$is_distinct_from(left, right)").to("left IS DISTINCT FROM right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") .map("$multiply(left: integer_type, right: integer_type)").to("left * right") @@ -321,7 +326,14 @@ public PostgreSqlClient( .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") .map("$is_null(value)").to("value IS NULL") - .map("$nullif(first, second)").to("NULLIF(first, second)") + .map("$nullif(first, second)").to("NULLIF(first, second)"); + this.connectorExpressionRewriter = connectorExpressionRewriterBuilder.build(); + this.connectorExpressionRewriterWithCollate = connectorExpressionRewriterBuilder + .withTypeClass("collatable_type", ImmutableSet.of("char", "varchar")) + .map("$less_than(left: collatable_type, right: collatable_type)").to("left < right COLLATE \"C\"") + .map("$less_than_or_equal(left: collatable_type, right: collatable_type)").to("left <= right COLLATE \"C\"") + .map("$greater_than(left: collatable_type, right: collatable_type)").to("left > right COLLATE \"C\"") + .map("$greater_than_or_equal(left: collatable_type, right: collatable_type)").to("left >= right COLLATE \"C\"") .build(); JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); @@ -786,6 +798,9 @@ public boolean supportsAggregationPushdown(ConnectorSession session, JdbcTableHa @Override public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) { + if (isEnableStringPushdownWithCollate(session)) { + return connectorExpressionRewriterWithCollate.rewrite(session, expression, assignments); + } return connectorExpressionRewriter.rewrite(session, expression, assignments); } @@ -1040,10 +1055,10 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { if (joinType == JoinType.FULL_OUTER) { @@ -1056,7 +1071,7 @@ public Optional implementJoin( leftSource, rightSource, statistics, - () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } @Override diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index fd47c631f77d..90692c075be7 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.postgresql; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.plugin.base.mapping.DefaultIdentifierMapping; import io.trino.plugin.jdbc.BaseJdbcConfig; @@ -33,6 +34,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.expression.Variable; +import io.trino.spi.session.PropertyMetadata; import io.trino.spi.type.Type; import io.trino.sql.planner.ConnectorExpressionTranslator; import io.trino.sql.planner.LiteralEncoder; @@ -118,7 +120,10 @@ public class TestPostgreSqlClient private static final ConnectorSession SESSION = TestingConnectorSession .builder() - .setPropertyMetadata(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties()) + .setPropertyMetadata(ImmutableList.>builder() + .addAll(new JdbcMetadataSessionProperties(new JdbcMetadataConfig(), Optional.empty()).getSessionProperties()) + .addAll(new PostgreSqlSessionProperties(new PostgreSqlConfig()).getSessionProperties()) + .build()) .build(); @Test diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index ba04e2693f02..d39389fc0a7c 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -407,10 +407,10 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon public Optional implementJoin(ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { if (joinType == JoinType.FULL_OUTER) { @@ -423,7 +423,7 @@ public Optional implementJoin(ConnectorSession session, leftSource, rightSource, statistics, - () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } @Override diff --git a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java index b715c1b293d8..a080363ff4b7 100644 --- a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java +++ b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.base.mapping.IdentifierMapping; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; @@ -31,6 +32,8 @@ import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.RemoteTableName; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -40,6 +43,7 @@ import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -151,6 +155,7 @@ public class SingleStoreClient private static final Pattern UNSIGNED_TYPE_REGEX = Pattern.compile("(?i).*unsigned$"); private final Type jsonType; + private final ConnectorExpressionRewriter connectorExpressionRewriter; @Inject public SingleStoreClient( @@ -183,6 +188,19 @@ protected SingleStoreClient( super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, supportsRetries); requireNonNull(typeManager, "typeManager is null"); this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); + + this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + .addStandardRules(this::quoted) + // No "real" on the list; pushdown on REAL is disabled also in toColumnMapping + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "double")) + .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") + .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") + // .map("$is_distinct_from(left: numeric_type, right: numeric_type)").to("left IS DISTINCT FROM right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .build(); } @Override @@ -553,22 +571,28 @@ public boolean isTopNGuaranteed(ConnectorSession session) return true; } + @Override + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + { + return connectorExpressionRewriter.rewrite(session, expression, assignments); + } + @Override public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { if (joinType == JoinType.FULL_OUTER) { // Not supported in SingleStore return Optional.empty(); } - return super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } @Override diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index fcb974b93988..2f1fac35ff77 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -63,7 +63,7 @@ import io.trino.plugin.jdbc.expression.ComparisonOperator; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; -import io.trino.plugin.jdbc.expression.RewriteComparison; +import io.trino.plugin.jdbc.expression.RewriteCaseSensitiveComparison; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.TrinoException; @@ -301,9 +301,16 @@ public SqlServerClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) - .add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .add(new RewriteIn()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) + .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) + .map("$equal(left: numeric_type, right: numeric_type)").to("left = right") + .map("$not_equal(left: numeric_type, right: numeric_type)").to("left <> right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .add(new RewriteCaseSensitiveComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL))) .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") .map("$multiply(left: integer_type, right: integer_type)").to("left * right") @@ -882,10 +889,10 @@ public Optional implementJoin( ConnectorSession session, JoinType joinType, PreparedQuery leftSource, + Map leftProjections, PreparedQuery rightSource, - List joinConditions, - Map rightAssignments, - Map leftAssignments, + Map rightProjections, + List joinConditions, JoinStatistics statistics) { return implementJoinCostAware( @@ -894,7 +901,7 @@ public Optional implementJoin( leftSource, rightSource, statistics, - () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } private LongWriteFunction sqlServerTimeWriteFunction(int precision) From 4450dde87cd61b1ee4e5d6c75d12b5ea9ab5a6cf Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 8 Dec 2023 15:43:07 +0100 Subject: [PATCH 6/6] Re-add expression-less join pushdown as fallback Restore older JDBC join pushdown implementation not based on `ConnectorExpression` as a fallback. This comes as a separate commit so that the introduction of `ConnectorExpression`-based join pushdown can be seen (e.g. reviewed) as a _change_, not as an _addition_. --- .../io/trino/plugin/jdbc/BaseJdbcClient.java | 35 ++++++ .../trino/plugin/jdbc/CachingJdbcClient.java | 14 +++ .../plugin/jdbc/DefaultJdbcMetadata.java | 110 ++++++++++++++++++ .../plugin/jdbc/DefaultQueryBuilder.java | 40 +++++++ .../plugin/jdbc/ForwardingJdbcClient.java | 14 +++ .../java/io/trino/plugin/jdbc/JdbcClient.java | 11 ++ .../trino/plugin/jdbc/JdbcMetadataConfig.java | 14 +++ .../jdbc/JdbcMetadataSessionProperties.java | 11 ++ .../io/trino/plugin/jdbc/QueryBuilder.java | 11 ++ .../jdbc/jmx/StatisticsAwareJdbcClient.java | 14 +++ .../plugin/jdbc/BaseJdbcConnectorTest.java | 9 ++ .../jdbc/TestDefaultJdbcQueryBuilder.java | 34 ++++++ .../plugin/jdbc/TestJdbcMetadataConfig.java | 3 + .../io/trino/plugin/ignite/IgniteClient.java | 19 +++ .../trino/plugin/mariadb/MariaDbClient.java | 18 +++ .../io/trino/plugin/mysql/MySqlClient.java | 24 ++++ .../plugin/postgresql/PostgreSqlClient.java | 24 ++++ .../trino/plugin/redshift/RedshiftClient.java | 23 ++++ .../plugin/singlestore/SingleStoreClient.java | 18 +++ .../plugin/sqlserver/SqlServerClient.java | 20 ++++ 20 files changed, 466 insertions(+) diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 989317dae45d..42d6731d52cc 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -550,6 +550,41 @@ public Optional implementJoin( } } + @Deprecated + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + for (JdbcJoinCondition joinCondition : joinConditions) { + if (!isSupportedJoinCondition(session, joinCondition)) { + return Optional.empty(); + } + } + + try (Connection connection = this.connectionFactory.openConnection(session)) { + return Optional.of(queryBuilder.legacyPrepareJoinQuery( + this, + session, + connection, + joinType, + leftSource, + rightSource, + joinConditions, + leftAssignments, + rightAssignments)); + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { return false; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index f473c94d96b9..ecea7b964a4e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -293,6 +293,20 @@ public Optional implementJoin( return delegate.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return delegate.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + } + @Override public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 607d813b49ef..7d0d7f6d6b08 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -40,6 +40,7 @@ import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.JoinApplicationResult; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinStatistics; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; @@ -94,6 +95,7 @@ import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isAggregationPushdownEnabled; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexExpressionPushdown; +import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexJoinPushdownEnabled; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled; import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert; @@ -447,6 +449,19 @@ public Optional> applyJoin( Map rightAssignments, JoinStatistics statistics) { + if (!isComplexJoinPushdownEnabled(session)) { + // Fallback to the old join pushdown code + return JdbcMetadata.super.applyJoin( + session, + joinType, + left, + right, + joinCondition, + leftAssignments, + rightAssignments, + statistics); + } + if (isTableHandleForProcedure(left) || isTableHandleForProcedure(right)) { return Optional.empty(); } @@ -536,6 +551,101 @@ public Optional> applyJoin( precalculateStatisticsForPushdown)); } + @Deprecated + @Override + public Optional> applyJoin( + ConnectorSession session, + JoinType joinType, + ConnectorTableHandle left, + ConnectorTableHandle right, + List joinConditions, + Map leftAssignments, + Map rightAssignments, + JoinStatistics statistics) + { + if (isTableHandleForProcedure(left) || isTableHandleForProcedure(right)) { + return Optional.empty(); + } + + if (!isJoinPushdownEnabled(session)) { + return Optional.empty(); + } + + JdbcTableHandle leftHandle = flushAttributesAsQuery(session, (JdbcTableHandle) left); + JdbcTableHandle rightHandle = flushAttributesAsQuery(session, (JdbcTableHandle) right); + + if (!leftHandle.getAuthorization().equals(rightHandle.getAuthorization())) { + return Optional.empty(); + } + int nextSyntheticColumnId = max(leftHandle.getNextSyntheticColumnId(), rightHandle.getNextSyntheticColumnId()); + + ImmutableMap.Builder newLeftColumnsBuilder = ImmutableMap.builder(); + OptionalInt maxColumnNameLength = jdbcClient.getMaxColumnNameLength(session); + for (JdbcColumnHandle column : jdbcClient.getColumns(session, leftHandle)) { + newLeftColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength)); + nextSyntheticColumnId++; + } + Map newLeftColumns = newLeftColumnsBuilder.buildOrThrow(); + + ImmutableMap.Builder newRightColumnsBuilder = ImmutableMap.builder(); + for (JdbcColumnHandle column : jdbcClient.getColumns(session, rightHandle)) { + newRightColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength)); + nextSyntheticColumnId++; + } + Map newRightColumns = newRightColumnsBuilder.buildOrThrow(); + + ImmutableList.Builder jdbcJoinConditions = ImmutableList.builder(); + for (JoinCondition joinCondition : joinConditions) { + Optional leftColumn = getVariableColumnHandle(leftAssignments, joinCondition.getLeftExpression()); + Optional rightColumn = getVariableColumnHandle(rightAssignments, joinCondition.getRightExpression()); + if (leftColumn.isEmpty() || rightColumn.isEmpty()) { + return Optional.empty(); + } + jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get())); + } + + Optional joinQuery = jdbcClient.legacyImplementJoin( + session, + joinType, + asPreparedQuery(leftHandle), + asPreparedQuery(rightHandle), + jdbcJoinConditions.build(), + newRightColumns.entrySet().stream() + .collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())), + newLeftColumns.entrySet().stream() + .collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())), + statistics); + + if (joinQuery.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new JoinApplicationResult<>( + new JdbcTableHandle( + new JdbcQueryRelationHandle(joinQuery.get()), + TupleDomain.all(), + ImmutableList.of(), + Optional.empty(), + OptionalLong.empty(), + Optional.of( + ImmutableList.builder() + .addAll(newLeftColumns.values()) + .addAll(newRightColumns.values()) + .build()), + leftHandle.getAllReferencedTables().flatMap(leftReferencedTables -> + rightHandle.getAllReferencedTables().map(rightReferencedTables -> + ImmutableSet.builder() + .addAll(leftReferencedTables) + .addAll(rightReferencedTables) + .build())), + nextSyntheticColumnId, + leftHandle.getAuthorization(), + leftHandle.getUpdateAssignments()), + ImmutableMap.copyOf(newLeftColumns), + ImmutableMap.copyOf(newRightColumns), + precalculateStatisticsForPushdown)); + } + @VisibleForTesting static JdbcColumnHandle createSyntheticJoinProjectionColumn(JdbcColumnHandle column, int nextSyntheticColumnId, OptionalInt optionalMaxColumnNameLength) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index 7d3a4f5f0524..bd06719f9ce1 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -147,6 +147,46 @@ public PreparedQuery prepareJoinQuery( return new PreparedQuery(query, parameters); } + @Override + public PreparedQuery legacyPrepareJoinQuery( + JdbcClient client, + ConnectorSession session, + Connection connection, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map leftAssignments, + Map rightAssignments) + { + // Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here. + verify(!leftAssignments.isEmpty(), "leftAssignments is empty"); + verify(!rightAssignments.isEmpty(), "rightAssignments is empty"); + // Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here + verify(!joinConditions.isEmpty(), "joinConditions is empty"); + + String leftRelationAlias = "l"; + String rightRelationAlias = "r"; + + String query = format( + "SELECT %s, %s FROM (%s) %s %s (%s) %s ON %s", + formatAssignments(client, leftRelationAlias, leftAssignments), + formatAssignments(client, rightRelationAlias, rightAssignments), + leftSource.getQuery(), + leftRelationAlias, + formatJoinType(joinType), + rightSource.getQuery(), + rightRelationAlias, + joinConditions.stream() + .map(condition -> formatJoinCondition(client, leftRelationAlias, rightRelationAlias, condition)) + .collect(joining(" AND "))); + List parameters = ImmutableList.builder() + .addAll(leftSource.getParameters()) + .addAll(rightSource.getParameters()) + .build(); + return new PreparedQuery(query, parameters); + } + @Override public PreparedQuery prepareDeleteQuery( JdbcClient client, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index f2ac9981f9ef..3f672896d56c 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -220,6 +220,20 @@ public Optional implementJoin( return delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return delegate().legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + } + @Override public JdbcOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index a52f100c78b1..cb2ff550a8e8 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -133,6 +133,17 @@ Optional implementJoin( List joinConditions, JoinStatistics statistics); + @Deprecated + Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics); + boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder); /** diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java index 039195481a86..d6896c157a1a 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java @@ -30,6 +30,7 @@ public class JdbcMetadataConfig * in terms of performance and money due to an increased network traffic. */ private boolean joinPushdownEnabled; + private boolean complexJoinPushdownEnabled = true; private boolean aggregationPushdownEnabled = true; private boolean topNPushdownEnabled = true; @@ -67,6 +68,19 @@ public JdbcMetadataConfig setJoinPushdownEnabled(boolean joinPushdownEnabled) return this; } + public boolean isComplexJoinPushdownEnabled() + { + return complexJoinPushdownEnabled; + } + + @Config("join-pushdown.with-expressions") + @ConfigDescription("Enable join pushdown with complex expressions") + public JdbcMetadataConfig setComplexJoinPushdownEnabled(boolean complexJoinPushdownEnabled) + { + this.complexJoinPushdownEnabled = complexJoinPushdownEnabled; + return this; + } + public boolean isAggregationPushdownEnabled() { return aggregationPushdownEnabled; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java index d4ae2a0b5b12..96476cce488e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java @@ -33,6 +33,7 @@ public class JdbcMetadataSessionProperties { public static final String COMPLEX_EXPRESSION_PUSHDOWN = "complex_expression_pushdown"; public static final String JOIN_PUSHDOWN_ENABLED = "join_pushdown_enabled"; + public static final String COMPLEX_JOIN_PUSHDOWN_ENABLED = "complex_join_pushdown_enabled"; public static final String AGGREGATION_PUSHDOWN_ENABLED = "aggregation_pushdown_enabled"; public static final String TOPN_PUSHDOWN_ENABLED = "topn_pushdown_enabled"; public static final String DOMAIN_COMPACTION_THRESHOLD = "domain_compaction_threshold"; @@ -54,6 +55,11 @@ public JdbcMetadataSessionProperties(JdbcMetadataConfig jdbcMetadataConfig, @Max "Enable join pushdown", jdbcMetadataConfig.isJoinPushdownEnabled(), false)) + .add(booleanProperty( + COMPLEX_JOIN_PUSHDOWN_ENABLED, + "Enable join pushdown with non-comparison expressions", + jdbcMetadataConfig.isComplexJoinPushdownEnabled(), + false)) .add(booleanProperty( AGGREGATION_PUSHDOWN_ENABLED, "Enable aggregation pushdown", @@ -89,6 +95,11 @@ public static boolean isJoinPushdownEnabled(ConnectorSession session) return session.getProperty(JOIN_PUSHDOWN_ENABLED, Boolean.class); } + public static boolean isComplexJoinPushdownEnabled(ConnectorSession session) + { + return session.getProperty(COMPLEX_JOIN_PUSHDOWN_ENABLED, Boolean.class); + } + public static boolean isAggregationPushdownEnabled(ConnectorSession session) { return session.getProperty(AGGREGATION_PUSHDOWN_ENABLED, Boolean.class); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index a6dc6f53d985..9c6fbf7c4ca9 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -52,6 +52,17 @@ PreparedQuery prepareJoinQuery( Map rightProjections, List joinConditions); + PreparedQuery legacyPrepareJoinQuery( + JdbcClient client, + ConnectorSession session, + Connection connection, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map leftAssignments, + Map rightAssignments); + PreparedQuery prepareDeleteQuery( JdbcClient client, ConnectorSession session, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index a19087465750..339cf17df0f5 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -17,6 +17,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcProcedureHandle; import io.trino.plugin.jdbc.JdbcProcedureHandle.ProcedureQuery; @@ -239,6 +240,19 @@ public Optional implementJoin(ConnectorSession session, return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } + @Override + public Optional legacyImplementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return stats.getImplementJoin().wrap(() -> delegate().legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override public Optional getTableComment(ResultSet resultSet) throws SQLException diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 1cf08c54b3fd..63ab5bebabdf 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -64,6 +64,7 @@ import static io.trino.SystemSessionProperties.MARK_DISTINCT_STRATEGY; import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_ENABLED; import static io.trino.plugin.jdbc.JdbcDynamicFilteringSessionProperties.DYNAMIC_FILTERING_WAIT_TIMEOUT; +import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.COMPLEX_JOIN_PUSHDOWN_ENABLED; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.DOMAIN_COMPACTION_THRESHOLD; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.JOIN_PUSHDOWN_ENABLED; import static io.trino.plugin.jdbc.JoinOperator.FULL_JOIN; @@ -1332,6 +1333,14 @@ public void testComplexJoinPushdown() Session session = joinPushdownEnabled(getSession()); String query = "SELECT n.name, o.orderstatus FROM nation n JOIN orders o ON n.regionkey = o.orderkey AND n.nationkey + o.custkey - 3 = 0"; + // The join cannot be pushed down without "complex join pushdown" + assertThat(query( + Session.builder(session) + .setCatalogSessionProperty(catalog, COMPLEX_JOIN_PUSHDOWN_ENABLED, "false") + .build(), + query)) + .joinIsNotFullyPushedDown(); + // The join can be pushed down assertJoinConditionallyPushedDown( session, diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java index a407d624d5d8..170df828714b 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestDefaultJdbcQueryBuilder.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.predicate.Domain; @@ -537,6 +538,39 @@ public void testBuildJoinSql() } } + @Test + public void testBuildJoinSqlLegacy() + throws SQLException + { + Connection connection = database.getConnection(); + + PreparedQuery preparedQuery = queryBuilder.legacyPrepareJoinQuery( + jdbcClient, + SESSION, + connection, + JoinType.INNER, + new PreparedQuery("SELECT * FROM \"test_table\"", List.of()), + new PreparedQuery("SELECT * FROM \"test_table\"", List.of()), + List.of(new JdbcJoinCondition(columns.get(7), JoinCondition.Operator.EQUAL, columns.get(8))), + Map.of(columns.get(2), "name1"), + Map.of(columns.get(3), "name2")); + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(jdbcClient, SESSION, connection, preparedQuery, Optional.empty())) { + assertThat(preparedQuery.getQuery()).isEqualTo("" + + "SELECT l.\"col_2\" AS \"name1\", r.\"col_3\" AS \"name2\" FROM " + + "(SELECT * FROM \"test_table\") l " + + "INNER JOIN " + + "(SELECT * FROM \"test_table\") r " + + "ON l.\"col_7\" = r.\"col_8\""); + long count = 0; + try (ResultSet resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + count++; + } + } + assertThat(count).isEqualTo(8); + } + } + @Test public void testBuildSqlWithLimit() throws SQLException diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java index ed09ca49e4ba..e86c88d7a3fc 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java @@ -30,6 +30,7 @@ public void testDefaults() assertRecordedDefaults(recordDefaults(JdbcMetadataConfig.class) .setComplexExpressionPushdownEnabled(true) .setJoinPushdownEnabled(false) + .setComplexJoinPushdownEnabled(true) .setAggregationPushdownEnabled(true) .setTopNPushdownEnabled(true) .setDomainCompactionThreshold(32)); @@ -41,6 +42,7 @@ public void testExplicitPropertyMappings() Map properties = ImmutableMap.builder() .put("complex-expression-pushdown.enabled", "false") .put("join-pushdown.enabled", "true") + .put("join-pushdown.with-expressions", "false") .put("aggregation-pushdown.enabled", "false") .put("domain-compaction-threshold", "42") .put("topn-pushdown.enabled", "false") @@ -49,6 +51,7 @@ public void testExplicitPropertyMappings() JdbcMetadataConfig expected = new JdbcMetadataConfig() .setComplexExpressionPushdownEnabled(false) .setJoinPushdownEnabled(true) + .setComplexJoinPushdownEnabled(false) .setAggregationPushdownEnabled(false) .setTopNPushdownEnabled(false) .setDomainCompactionThreshold(42); diff --git a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java index 7af9b61915fc..9fa2f946b2ad 100644 --- a/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java +++ b/plugin/trino-ignite/src/main/java/io/trino/plugin/ignite/IgniteClient.java @@ -595,6 +595,25 @@ public Optional implementJoin( return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + // Ignite does not support FULL JOIN + if (joinType == JoinType.FULL_OUTER) { + return Optional.empty(); + } + + return super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { diff --git a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java index 2f8d2ff527de..18b6e2ec02b5 100644 --- a/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java +++ b/plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java @@ -652,6 +652,24 @@ public Optional implementJoin( return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // Not supported in MariaDB + return Optional.empty(); + } + return super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index cece995ce02b..241ffd0cf051 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -1033,6 +1033,30 @@ public Optional implementJoin( () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // Not supported in MySQL + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 976f09c08285..ace69134eb2d 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -1074,6 +1074,30 @@ public Optional implementJoin( () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index d39389fc0a7c..c1cc81b72033 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -426,6 +426,29 @@ public Optional implementJoin(ConnectorSession session, () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } + @Override + public Optional legacyImplementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException diff --git a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java index a080363ff4b7..2f2afd7884cb 100644 --- a/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java +++ b/plugin/trino-singlestore/src/main/java/io/trino/plugin/singlestore/SingleStoreClient.java @@ -595,6 +595,24 @@ public Optional implementJoin( return super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // Not supported in SingleStore + return Optional.empty(); + } + return super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics); + } + @Override protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) { diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 2f1fac35ff77..a148e184ce72 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -904,6 +904,26 @@ public Optional implementJoin( () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); } + @Override + public Optional legacyImplementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + private LongWriteFunction sqlServerTimeWriteFunction(int precision) { return new LongWriteFunction()