From 2dca5ff77723f3b1307117bdaff78e6b45a693f3 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Thu, 11 Feb 2021 21:44:52 +0100 Subject: [PATCH] Implement Join pushdown for JDBC connectors --- .../iterative/rule/PushJoinIntoTableScan.java | 44 +-- .../rule/TestPushJoinIntoTableScan.java | 13 +- .../io/trino/sql/query/QueryAssertions.java | 21 +- .../io/trino/spi/connector/JoinCondition.java | 15 + plugin/trino-base-jdbc/pom.xml | 13 + .../io/trino/plugin/jdbc/BaseJdbcClient.java | 33 ++ .../trino/plugin/jdbc/CachingJdbcClient.java | 14 + .../plugin/jdbc/ForwardingJdbcClient.java | 14 + .../java/io/trino/plugin/jdbc/JdbcClient.java | 10 + .../trino/plugin/jdbc/JdbcJoinCondition.java | 47 +++ .../io/trino/plugin/jdbc/JdbcMetadata.java | 111 +++++++ .../trino/plugin/jdbc/JdbcMetadataConfig.java | 20 ++ .../jdbc/JdbcMetadataSessionProperties.java | 11 + .../io/trino/plugin/jdbc/QueryBuilder.java | 59 ++++ .../plugin/jdbc/jmx/JdbcClientStats.java | 8 + .../jdbc/jmx/StatisticsAwareJdbcClient.java | 14 + .../plugin/jdbc/BaseJdbcConnectorTest.java | 308 ++++++++++++++++++ .../plugin/jdbc/TestJdbcMetadataConfig.java | 3 + .../io/trino/plugin/memsql/MemSqlClient.java | 38 +++ .../memsql/TestMemSqlConnectorTest.java | 27 ++ .../io/trino/plugin/mysql/MySqlClient.java | 36 ++ .../plugin/mysql/BaseMySqlConnectorTest.java | 24 ++ .../io/trino/plugin/oracle/OracleClient.java | 8 + .../oracle/BaseOracleConnectorTest.java | 25 ++ .../plugin/postgresql/PostgreSqlClient.java | 29 ++ .../TestPostgreSqlConnectorTest.java | 50 +++ .../plugin/sqlserver/SqlServerClient.java | 17 + .../sqlserver/TestSqlServerConnectorTest.java | 23 ++ .../io/trino/testing/BaseConnectorTest.java | 23 ++ .../testing/TestingConnectorBehavior.java | 28 ++ 30 files changed, 1056 insertions(+), 30 deletions(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinCondition.java create mode 100644 testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index fe86220254cc..ae76707ba30b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -36,9 +36,11 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.Patterns; import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.tree.BooleanLiteral; import io.trino.sql.tree.ComparisonExpression; @@ -92,6 +94,12 @@ public Pattern getPattern() return PATTERN; } + @Override + public boolean isEnabled(Session session) + { + return isAllowPushdownIntoConnectors(session); + } + @Override public Result apply(JoinNode joinNode, Captures captures, Context context) { @@ -105,7 +113,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) verify(!left.isForDelete() && !right.isForDelete(), "Unexpected Join over for-delete table scan"); Expression effectiveFilter = getEffectiveFilter(joinNode); - FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), context); + FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context); if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) { // TODO add extra filter node above join @@ -156,13 +164,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) Map leftColumnHandlesMapping = joinApplicationResult.get().getLeftColumnHandles(); Map rightColumnHandlesMapping = joinApplicationResult.get().getRightColumnHandles(); - ImmutableMap.Builder newAssignments = ImmutableMap.builder(); - newAssignments.putAll(left.getAssignments().entrySet().stream().collect(toImmutableMap( + ImmutableMap.Builder assignmentsBuilder = ImmutableMap.builder(); + assignmentsBuilder.putAll(left.getAssignments().entrySet().stream().collect(toImmutableMap( Map.Entry::getKey, entry -> leftColumnHandlesMapping.get(entry.getValue())))); - newAssignments.putAll(right.getAssignments().entrySet().stream().collect(toImmutableMap( + assignmentsBuilder.putAll(right.getAssignments().entrySet().stream().collect(toImmutableMap( Map.Entry::getKey, entry -> rightColumnHandlesMapping.get(entry.getValue())))); + Map assignments = assignmentsBuilder.build(); // convert enforced constraint JoinNode.Type joinType = joinNode.getType(); @@ -176,7 +185,17 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) .putAll(rightConstraint.getDomains().orElseThrow()) .build()); - return Result.ofPlanNode(new TableScanNode(joinNode.getId(), handle, joinNode.getOutputSymbols(), newAssignments.build(), newEnforcedConstraint, false)); + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new TableScanNode( + joinNode.getId(), + handle, + ImmutableList.copyOf(assignments.keySet()), + assignments, + newEnforcedConstraint, + false), + Assignments.identity(joinNode.getOutputSymbols()))); } private JoinStatistics getJoinStatistics(JoinNode join, TableScanNode left, TableScanNode right, Context context) @@ -228,21 +247,6 @@ private TupleDomain deriveConstraint(TupleDomain sou }); } - @Override - public boolean isEnabled(Session session) - { - return isAllowPushdownIntoConnectors(session); - } - - private TupleDomain transformToNewAssignments(TupleDomain tupleDomain, Map newAssignments) - { - return tupleDomain.transform(handle -> { - ColumnHandle newHandle = newAssignments.get(handle); - checkArgument(newHandle != null, "Mapping not found for handle %s", handle); - return newHandle; - }); - } - public Expression getEffectiveFilter(JoinNode node) { Expression effectiveFilter = and(node.getCriteria().stream().map(JoinNode.EquiJoinClause::toExpression).collect(toImmutableList())); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index 3ee9c686171d..7bb8b83bc9e5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -54,6 +54,7 @@ import static io.trino.spi.predicate.Domain.onlyNull; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester; import static io.trino.sql.planner.plan.JoinNode.Type.FULL; @@ -174,7 +175,8 @@ public void testPushJoinIntoTableScan(JoinNode.Type joinType, Optional JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.equals(((MockConnectorTableHandle) tableHandle).getTableName()), - expectedConstraint, - ImmutableMap.of())); + project( + tableScan( + tableHandle -> JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.equals(((MockConnectorTableHandle) tableHandle).getTableName()), + expectedConstraint, + ImmutableMap.of()))); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index 91da20c4436f..b75dea64fea4 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -385,6 +385,22 @@ public QueryAssert isFullyPushedDown() public final QueryAssert isNotFullyPushedDown(Class... retainedNodes) { checkArgument(retainedNodes.length > 0, "No retainedNodes"); + PlanMatchPattern expectedPlan = PlanMatchPattern.node(TableScanNode.class); + for (Class retainedNode : ImmutableList.copyOf(retainedNodes).reverse()) { + expectedPlan = PlanMatchPattern.node(retainedNode, expectedPlan); + } + return isNotFullyPushedDown(expectedPlan); + } + + /** + * Verifies query is not fully pushed down and verifies the results are the same as when the pushdown is fully disabled. + *

+ * Note: the primary intent of this assertion is to ensure the test is updated to {@link #isFullyPushedDown()} + * when pushdown capabilities are improved. + */ + public final QueryAssert isNotFullyPushedDown(PlanMatchPattern retainedSubplan) + { + PlanMatchPattern expectedPlan = PlanMatchPattern.anyTree(retainedSubplan); // Compare the results with pushdown disabled, so that explicit matches() call is not needed verifyResultsWithPushdownDisabled(); @@ -392,11 +408,6 @@ public final QueryAssert isNotFullyPushedDown(Class... retai transaction(runner.getTransactionManager(), runner.getAccessControl()) .execute(session, session -> { Plan plan = runner.createPlan(session, query, WarningCollector.NOOP); - PlanMatchPattern expectedPlan = PlanMatchPattern.node(TableScanNode.class); - for (Class retainedNode : ImmutableList.copyOf(retainedNodes).reverse()) { - expectedPlan = PlanMatchPattern.node(retainedNode, expectedPlan); - } - expectedPlan = PlanMatchPattern.anyTree(expectedPlan); assertPlan( session, runner.getMetadata(), diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java index 53a465ee5ec3..4d2b27c12e20 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/JoinCondition.java @@ -56,6 +56,21 @@ public JoinCondition(Operator operator, ConnectorExpression leftExpression, Conn this.rightExpression = requireNonNull(rightExpression, "rightExpression is null"); } + public Operator getOperator() + { + return operator; + } + + public ConnectorExpression getLeftExpression() + { + return leftExpression; + } + + public ConnectorExpression getRightExpression() + { + return rightExpression; + } + @Override public boolean equals(Object o) { diff --git a/plugin/trino-base-jdbc/pom.xml b/plugin/trino-base-jdbc/pom.xml index a38cbe6fbb7d..6e453538a36e 100644 --- a/plugin/trino-base-jdbc/pom.xml +++ b/plugin/trino-base-jdbc/pom.xml @@ -142,6 +142,13 @@ test + + io.trino + trino-main + test-jar + test + + io.trino trino-spi @@ -185,6 +192,12 @@ test + + org.jetbrains + annotations + test + + org.testng testng diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index 2bf7db3a9955..563bac6ed82b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -30,6 +30,7 @@ import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.FixedSplitSource; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.TableNotFoundException; @@ -445,6 +446,38 @@ public PreparedQuery prepareQuery( } } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + for (JdbcJoinCondition joinCondition : joinConditions) { + if (!isSupportedJoinCondition(joinCondition)) { + return Optional.empty(); + } + } + + QueryBuilder queryBuilder = new QueryBuilder(this); + return Optional.of(queryBuilder.prepareJoinQuery( + session, + joinType, + leftSource, + rightSource, + joinConditions, + leftAssignments, + rightAssignments)); + } + + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + return false; + } + @Override public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 48e362ed5e51..9f5f9cb5136d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -27,6 +27,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -205,6 +206,19 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return delegate.buildSql(session, connection, split, table, columns); } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + return delegate.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments); + } + @Override public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 86dd77628a9e..61f6836fc211 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -19,6 +19,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -154,6 +155,19 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return delegate().buildSql(session, connection, split, tableHandle, columnHandles); } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + return delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments); + } + @Override public JdbcOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index f3d2792ffbde..64f56b9b4b02 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -20,6 +20,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -93,6 +94,15 @@ PreparedQuery prepareQuery( PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List columns) throws SQLException; + Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments); + boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder); boolean isTopNLimitGuaranteed(ConnectorSession session); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinCondition.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinCondition.java new file mode 100644 index 000000000000..f8b23c4a5548 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcJoinCondition.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc; + +import io.trino.spi.connector.JoinCondition.Operator; + +import static java.util.Objects.requireNonNull; + +public class JdbcJoinCondition +{ + private final JdbcColumnHandle leftColumn; + private final Operator operator; + private final JdbcColumnHandle rightColumn; + + public JdbcJoinCondition(JdbcColumnHandle leftColumn, Operator operator, JdbcColumnHandle rightColumn) + { + this.leftColumn = requireNonNull(leftColumn, "leftColumn is null"); + this.operator = requireNonNull(operator, "operator is null"); + this.rightColumn = requireNonNull(rightColumn, "rightColumn is null"); + } + + public JdbcColumnHandle getLeftColumn() + { + return leftColumn; + } + + public Operator getOperator() + { + return operator; + } + + public JdbcColumnHandle getRightColumn() + { + return rightColumn; + } +} diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java index a75220a0d28b..f5d831abbccf 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadata.java @@ -35,6 +35,10 @@ import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; +import io.trino.spi.connector.JoinApplicationResult; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.ProjectionApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -62,13 +66,17 @@ import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Functions.identity; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; +import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isAggregationPushdownEnabled; +import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled; import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED; +import static java.lang.Math.max; import static java.util.Objects.requireNonNull; public class JdbcMetadata @@ -298,6 +306,109 @@ public Optional> applyAggrega return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of())); } + @Override + public Optional> applyJoin( + ConnectorSession session, + JoinType joinType, + ConnectorTableHandle left, + ConnectorTableHandle right, + List joinConditions, + Map leftAssignments, + Map rightAssignments, + JoinStatistics statistics) + { + if (!isJoinPushdownEnabled(session)) { + return Optional.empty(); + } + + JdbcTableHandle leftHandle = flushAttributesAsQuery(session, (JdbcTableHandle) left); + JdbcTableHandle rightHandle = flushAttributesAsQuery(session, (JdbcTableHandle) right); + int nextSyntheticColumnId = max(leftHandle.getNextSyntheticColumnId(), rightHandle.getNextSyntheticColumnId()); + + ImmutableMap.Builder newLeftColumnsBuilder = ImmutableMap.builder(); + for (JdbcColumnHandle column : jdbcClient.getColumns(session, leftHandle)) { + newLeftColumnsBuilder.put(column, JdbcColumnHandle.builderFrom(column) + .setColumnName(column.getColumnName() + "_" + nextSyntheticColumnId) + .build()); + nextSyntheticColumnId++; + } + Map newLeftColumns = newLeftColumnsBuilder.build(); + + ImmutableMap.Builder newRightColumnsBuilder = ImmutableMap.builder(); + for (JdbcColumnHandle column : jdbcClient.getColumns(session, rightHandle)) { + newRightColumnsBuilder.put(column, JdbcColumnHandle.builderFrom(column) + .setColumnName(column.getColumnName() + "_" + nextSyntheticColumnId) + .build()); + nextSyntheticColumnId++; + } + Map newRightColumns = newRightColumnsBuilder.build(); + + ImmutableList.Builder jdbcJoinConditions = ImmutableList.builder(); + for (JoinCondition joinCondition : joinConditions) { + Optional leftColumn = getVariableColumnHandle(leftAssignments, joinCondition.getLeftExpression()); + Optional rightColumn = getVariableColumnHandle(rightAssignments, joinCondition.getRightExpression()); + if (leftColumn.isEmpty() || rightColumn.isEmpty()) { + return Optional.empty(); + } + jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get())); + } + + Optional joinQuery = jdbcClient.implementJoin( + session, + joinType, + asPreparedQuery(leftHandle), + asPreparedQuery(rightHandle), + jdbcJoinConditions.build(), + newRightColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())), + newLeftColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName()))); + + if (joinQuery.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new JoinApplicationResult<>( + new JdbcTableHandle( + new JdbcQueryRelationHandle(joinQuery.get()), + TupleDomain.all(), + Optional.empty(), + OptionalLong.empty(), + Optional.of( + ImmutableList.builder() + .addAll(newLeftColumns.values()) + .addAll(newRightColumns.values()) + .build()), + nextSyntheticColumnId), + ImmutableMap.copyOf(newLeftColumns), + ImmutableMap.copyOf(newRightColumns))); + } + + private static Optional getVariableColumnHandle(Map assignments, ConnectorExpression expression) + { + requireNonNull(assignments, "assignments is null"); + requireNonNull(expression, "expression is null"); + if (!(expression instanceof Variable)) { + return Optional.empty(); + } + + String name = ((Variable) expression).getName(); + ColumnHandle columnHandle = assignments.get(name); + verifyNotNull(columnHandle, "No assignment for %s", name); + return Optional.of(((JdbcColumnHandle) columnHandle)); + } + + private static PreparedQuery asPreparedQuery(JdbcTableHandle tableHandle) + { + checkArgument( + tableHandle.getConstraint().equals(TupleDomain.all()) && + tableHandle.getLimit().isEmpty() && + tableHandle.getRelationHandle() instanceof JdbcQueryRelationHandle, + "Handle is not a plain query: %s", + tableHandle); + return ((JdbcQueryRelationHandle) tableHandle.getRelationHandle()).getPreparedQuery(); + } + @Override public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java index 6897ace5ac40..e9e0d647f792 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataConfig.java @@ -22,6 +22,13 @@ public class JdbcMetadataConfig { private boolean allowDropTable; + /* + * Join pushdown is disabled by default as this is the safer option. + * Pushing down a join which substantially increases the row count vs + * sizes of left and right table separately, may incur huge cost both + * in terms of performance and money due to an increased network traffic. + */ + private boolean joinPushdownEnabled; private boolean aggregationPushdownEnabled = true; // TODO: https://github.com/trinodb/trino/issues/7031 @@ -47,6 +54,19 @@ public JdbcMetadataConfig setAllowDropTable(boolean allowDropTable) return this; } + public boolean isJoinPushdownEnabled() + { + return joinPushdownEnabled; + } + + @Config("experimental.join-pushdown.enabled") + @ConfigDescription("Enable join pushdown") + public JdbcMetadataConfig setJoinPushdownEnabled(boolean joinPushdownEnabled) + { + this.joinPushdownEnabled = joinPushdownEnabled; + return this; + } + public boolean isAggregationPushdownEnabled() { return aggregationPushdownEnabled; diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java index 628f3c9e6d30..41653cdebf51 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcMetadataSessionProperties.java @@ -31,6 +31,7 @@ public class JdbcMetadataSessionProperties implements SessionPropertiesProvider { + public static final String JOIN_PUSHDOWN_ENABLED = "join_pushdown_enabled"; public static final String AGGREGATION_PUSHDOWN_ENABLED = "aggregation_pushdown_enabled"; public static final String TOPN_PUSHDOWN_ENABLED = "topn_pushdown_enabled"; public static final String DOMAIN_COMPACTION_THRESHOLD = "domain_compaction_threshold"; @@ -42,6 +43,11 @@ public JdbcMetadataSessionProperties(JdbcMetadataConfig jdbcMetadataConfig, @Max { validateDomainCompactionThreshold(jdbcMetadataConfig.getDomainCompactionThreshold(), maxDomainCompactionThreshold); properties = ImmutableList.>builder() + .add(booleanProperty( + JOIN_PUSHDOWN_ENABLED, + "Enable join pushdown", + jdbcMetadataConfig.isJoinPushdownEnabled(), + false)) .add(booleanProperty( AGGREGATION_PUSHDOWN_ENABLED, "Enable aggregation pushdown", @@ -67,6 +73,11 @@ public List> getSessionProperties() return properties; } + public static boolean isJoinPushdownEnabled(ConnectorSession session) + { + return session.getProperty(JOIN_PUSHDOWN_ENABLED, Boolean.class); + } + public static boolean isAggregationPushdownEnabled(ConnectorSession session) { return session.getProperty(AGGREGATION_PUSHDOWN_ENABLED, Boolean.class); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java index d5119c8613f7..808d7c2dea8d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/QueryBuilder.java @@ -20,6 +20,7 @@ import io.airlift.slice.Slice; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinType; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; @@ -99,6 +100,64 @@ public PreparedQuery prepareQuery( return new PreparedQuery(sql, accumulator.build()); } + public PreparedQuery prepareJoinQuery( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map leftAssignments, + Map rightAssignments) + { + // Verify assignments are present. This is safe assumption as join conditions are not pruned, and simplifies the code here. + verify(!leftAssignments.isEmpty(), "leftAssignments is empty"); + verify(!rightAssignments.isEmpty(), "rightAssignments is empty"); + // Joins wih no conditions are not pushed down, so it is a same assumption and simplifies the code here + verify(!joinConditions.isEmpty(), "joinConditions is empty"); + + String query = format( + "SELECT %s, %s FROM (%s) l %s (%s) r ON %s", + formatAssignments("l", leftAssignments), + formatAssignments("r", rightAssignments), + leftSource.getQuery(), + formatJoinType(joinType), + rightSource.getQuery(), + joinConditions.stream() + .map(condition -> format( + "l.%s %s r.%s", + client.quoted(condition.getLeftColumn().getColumnName()), + condition.getOperator().getValue(), + client.quoted(condition.getRightColumn().getColumnName()))) + .collect(joining(" AND "))); + List parameters = ImmutableList.builder() + .addAll(leftSource.getParameters()) + .addAll(rightSource.getParameters()) + .build(); + return new PreparedQuery(query, parameters); + } + + protected String formatAssignments(String relationAlias, Map assignments) + { + return assignments.entrySet().stream() + .map(entry -> format("%s.%s AS %s", relationAlias, client.quoted(entry.getKey().getColumnName()), client.quoted(entry.getValue()))) + .collect(joining(", ")); + } + + protected static String formatJoinType(JoinType joinType) + { + switch (joinType) { + case INNER: + return "INNER JOIN"; + case LEFT_OUTER: + return "LEFT JOIN"; + case RIGHT_OUTER: + return "RIGHT JOIN"; + case FULL_OUTER: + return "FULL JOIN"; + } + throw new IllegalStateException("Unsupported join type: " + joinType); + } + public PreparedStatement prepareStatement( ConnectorSession session, Connection connection, diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java index 4020bc33d2ef..4e61f56e22f5 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/JdbcClientStats.java @@ -25,6 +25,7 @@ public final class JdbcClientStats private final JdbcApiStats buildInsertSql = new JdbcApiStats(); private final JdbcApiStats prepareQuery = new JdbcApiStats(); private final JdbcApiStats buildSql = new JdbcApiStats(); + private final JdbcApiStats implementJoin = new JdbcApiStats(); private final JdbcApiStats commitCreateTable = new JdbcApiStats(); private final JdbcApiStats createSchema = new JdbcApiStats(); private final JdbcApiStats createTable = new JdbcApiStats(); @@ -101,6 +102,13 @@ public JdbcApiStats getBuildSql() return buildSql; } + @Managed + @Nested + public JdbcApiStats getImplementJoin() + { + return implementJoin; + } + @Managed @Nested public JdbcApiStats getCommitCreateTable() diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 43409483d248..b04a8fded812 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -17,6 +17,7 @@ import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcTableHandle; @@ -31,6 +32,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.SystemTable; @@ -172,6 +174,18 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio return stats.getBuildSql().wrap(() -> delegate().buildSql(session, connection, split, tableHandle, columnHandles)); } + @Override + public Optional implementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + return stats.getImplementJoin().wrap(() -> delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments)); + } + @Override public void setColumnComment(ConnectorSession session, JdbcTableHandle handle, JdbcColumnHandle column, Optional comment) { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index c58de234babc..70a37d3723d7 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -13,10 +13,318 @@ */ package io.trino.plugin.jdbc; +import io.trino.Session; +import io.trino.spi.connector.JoinCondition; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.plan.ExchangeNode; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.query.QueryAssertions.QueryAssert; import io.trino.testing.BaseConnectorTest; +import io.trino.testing.sql.TestTable; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.stream.Stream; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.MoreCollectors.toOptional; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_AGGREGATION_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_LIMIT_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_TOPN_PUSHDOWN; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; public abstract class BaseJdbcConnectorTest extends BaseConnectorTest { // TODO move common tests from connector-specific classes here + + @Test + public void testJoinPushdownDisabled() + { + // If join pushdown gets enabled by default, this test should use a session with join pushdown disabled + Session noJoinPushdown = Session.builder(getSession()) + // Disable dynamic filtering so that expected plans in case of no pushdown remain "simple" + .setSystemProperty("enable_dynamic_filtering", "false") + // Disable optimized hash generation so that expected plans in case of no pushdown remain "simple" + .setSystemProperty("optimize_hash_generation", "false") + .build(); + + PlanMatchPattern partitionedJoinOverTableScans = node(JoinNode.class, + exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, + node(TableScanNode.class)), + exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION, + exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPARTITION, + node(TableScanNode.class)))); + + assertThat(query(noJoinPushdown, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")) + .isNotFullyPushedDown(partitionedJoinOverTableScans); + } + + /** + * Verify !SUPPORTS_JOIN_PUSHDOWN declaration is true. + */ + @Test + public void verifySupportsJoinPushdownDeclaration() + { + if (hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) { + // Covered by testJoinPushdown + return; + } + + assertThat(query(joinPushdownEnabled(getSession()), "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")) + .isNotFullyPushedDown( + node(JoinNode.class, + anyTree(node(TableScanNode.class)), + anyTree(node(TableScanNode.class)))); + } + + @Test + public void testJoinPushdown() + { + PlanMatchPattern joinOverTableScans = + node(JoinNode.class, + anyTree(node(TableScanNode.class)), + anyTree(node(TableScanNode.class))); + + PlanMatchPattern broadcastJoinOverTableScans = + node(JoinNode.class, + node(TableScanNode.class), + exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER, + exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.REPLICATE, + node(TableScanNode.class)))); + + if (!hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) { + assertThat(query("SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")) + .isNotFullyPushedDown(joinOverTableScans); + return; + } + + Session session = joinPushdownEnabled(getSession()); + + // Disable DF here for the sake of negative test cases' expected plan. With DF enabled, some operators return in DF's FilterNode and some do not. + Session withoutDynamicFiltering = Session.builder(session) + .setSystemProperty("enable_dynamic_filtering", "false") + .build(); + + String notDistinctOperator = "IS NOT DISTINCT FROM"; + List nonEqualities = Stream.concat( + Stream.of(JoinCondition.Operator.values()) + .filter(operator -> operator != JoinCondition.Operator.EQUAL) + .map(JoinCondition.Operator::getValue), + Stream.of(notDistinctOperator)) + .collect(toImmutableList()); + + try (TestTable nationLowercaseTable = new TestTable( + // If a connector supports Join pushdown, but does not allow CTAS, we need to make the table creation here overridable. + getQueryRunner()::execute, + "nation_lowercase", + "AS SELECT nationkey, lower(name) name, regionkey FROM nation")) { + // basic case + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey")).isFullyPushedDown(); + + // join over different columns + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // pushdown when using USING + assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r USING(regionkey)")).isFullyPushedDown(); + + // varchar equality predicate + assertConditionallyPushedDown( + session, + "SELECT n.name, n2.regionkey FROM nation n JOIN nation n2 ON n.name = n2.name", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), + joinOverTableScans); + assertConditionallyPushedDown( + session, + format("SELECT n.name, nl.regionkey FROM nation n JOIN %s nl ON n.name = nl.name", nationLowercaseTable.getName()), + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), + joinOverTableScans); + + // multiple bigint predicates + assertThat(query(session, "SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey")) + .isFullyPushedDown(); + + // inequality + for (String operator : nonEqualities) { + // bigint inequality predicate + assertThat(query(withoutDynamicFiltering, format("SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey %s r.regionkey", operator))) + // Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes + .isNotFullyPushedDown(broadcastJoinOverTableScans); + + // varchar inequality predicate + assertThat(query(withoutDynamicFiltering, format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.name %s nl.name", nationLowercaseTable.getName(), operator))) + // Currently no pushdown as inequality predicate is removed from Join to maintain Cross Join and Filter as separate nodes + .isNotFullyPushedDown(broadcastJoinOverTableScans); + } + + // inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertConditionallyPushedDown( + session, + format("SELECT n.name, c.name FROM nation n JOIN customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", operator), + expectJoinPushdown(operator), + joinOverTableScans); + } + + // varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join + for (String operator : nonEqualities) { + assertConditionallyPushedDown( + session, + format("SELECT n.name, nl.name FROM nation n JOIN %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", nationLowercaseTable.getName(), operator), + expectVarcharJoinPushdown(operator), + joinOverTableScans); + } + + // LEFT JOIN + assertThat(query(session, "SELECT r.name, n.name FROM nation n LEFT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + assertThat(query(session, "SELECT r.name, n.name FROM region r LEFT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // RIGHT JOIN + assertThat(query(session, "SELECT r.name, n.name FROM nation n RIGHT JOIN region r ON n.nationkey = r.regionkey")).isFullyPushedDown(); + assertThat(query(session, "SELECT r.name, n.name FROM region r RIGHT JOIN nation n ON n.nationkey = r.regionkey")).isFullyPushedDown(); + + // FULL JOIN + assertConditionallyPushedDown( + session, + "SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.nationkey = r.regionkey", + hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN), + joinOverTableScans); + + // Join over a (double) predicate + assertThat(query(session, "" + + "SELECT c.name, n.name " + + "FROM (SELECT * FROM customer WHERE acctbal > 8000) c " + + "JOIN nation n ON c.custkey = n.nationkey")) + .isFullyPushedDown(); + + // Join over a varchar equality predicate + assertConditionallyPushedDown( + session, + "SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + + "JOIN nation n ON c.custkey = n.nationkey", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY), + joinOverTableScans); + + // Join over a varchar inequality predicate + assertConditionallyPushedDown( + session, + "SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " + + "JOIN nation n ON c.custkey = n.nationkey", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY), + joinOverTableScans); + + // join over aggregation + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " + + "JOIN region r ON n.rk = r.regionkey", + hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN), + joinOverTableScans); + + // join over LIMIT + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " + + "JOIN region r ON n.nationkey = r.regionkey", + hasBehavior(SUPPORTS_LIMIT_PUSHDOWN), + joinOverTableScans); + + // join over TopN + assertConditionallyPushedDown( + session, + "SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " + + "JOIN region r ON n.nationkey = r.regionkey", + hasBehavior(SUPPORTS_TOPN_PUSHDOWN), + joinOverTableScans); + + // join over join + assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey")) + .isFullyPushedDown(); + } + } + + private void assertConditionallyPushedDown( + Session session, + @Language("SQL") String query, + boolean condition, + PlanMatchPattern otherwiseExpected) + { + QueryAssert queryAssert = assertThat(query(session, query)); + if (condition) { + queryAssert.isFullyPushedDown(); + } + else { + queryAssert.isNotFullyPushedDown(otherwiseExpected); + } + } + + private boolean expectJoinPushdown(String operator) + { + if ("IS NOT DISTINCT FROM".equals(operator)) { + // TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM + return false; + } + switch (toJoinConditionOperator(operator)) { + case EQUAL: + case NOT_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return true; + case IS_DISTINCT_FROM: + return hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM); + } + throw new AssertionError(); // unreachable + } + + private boolean expectVarcharJoinPushdown(String operator) + { + if ("IS NOT DISTINCT FROM".equals(operator)) { + // TODO (https://github.com/trinodb/trino/issues/6967) support join pushdown for IS NOT DISTINCT FROM + return false; + } + switch (toJoinConditionOperator(operator)) { + case EQUAL: + case NOT_EQUAL: + return hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY); + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY); + case IS_DISTINCT_FROM: + return hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM) && hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY); + } + throw new AssertionError(); // unreachable + } + + private JoinCondition.Operator toJoinConditionOperator(String operator) + { + return Stream.of(JoinCondition.Operator.values()) + .filter(joinOperator -> joinOperator.getValue().equals(operator)) + .collect(toOptional()) + .orElseThrow(() -> new IllegalArgumentException("Not found: " + operator)); + } + + protected Session joinPushdownEnabled(Session session) + { + // If join pushdown gets enabled by default, tests should use default session + verify(!new JdbcMetadataConfig().isJoinPushdownEnabled()); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_enabled", "true") + .build(); + } } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java index 9828eb3986d6..ce7930506395 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcMetadataConfig.java @@ -29,6 +29,7 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(JdbcMetadataConfig.class) .setAllowDropTable(false) + .setJoinPushdownEnabled(false) .setAggregationPushdownEnabled(true) .setTopNPushdownEnabled(false) .setDomainCompactionThreshold(32)); @@ -39,6 +40,7 @@ public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() .put("allow-drop-table", "true") + .put("experimental.join-pushdown.enabled", "true") .put("aggregation-pushdown.enabled", "false") .put("domain-compaction-threshold", "42") .put("topn-pushdown.enabled", "true") @@ -46,6 +48,7 @@ public void testExplicitPropertyMappings() JdbcMetadataConfig expected = new JdbcMetadataConfig() .setAllowDropTable(true) + .setJoinPushdownEnabled(true) .setAggregationPushdownEnabled(false) .setTopNPushdownEnabled(true) .setDomainCompactionThreshold(42); diff --git a/plugin/trino-memsql/src/main/java/io/trino/plugin/memsql/MemSqlClient.java b/plugin/trino-memsql/src/main/java/io/trino/plugin/memsql/MemSqlClient.java index 76dd97e82b9a..a0f401a57384 100644 --- a/plugin/trino-memsql/src/main/java/io/trino/plugin/memsql/MemSqlClient.java +++ b/plugin/trino-memsql/src/main/java/io/trino/plugin/memsql/MemSqlClient.java @@ -19,14 +19,19 @@ import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.WriteMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.type.CharType; import io.trino.spi.type.Decimals; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; @@ -43,8 +48,10 @@ import java.sql.Types; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.google.common.base.Verify.verify; import static io.airlift.slice.Slices.utf8Slice; @@ -292,6 +299,37 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + if (joinType == JoinType.FULL_OUTER) { + // Not supported in MemSQL + return Optional.empty(); + } + return super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments); + } + + @Override + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + if (joinCondition.getOperator() == JoinCondition.Operator.IS_DISTINCT_FROM) { + // Not supported in MemSQL + return false; + } + + // Remote database can be case insensitive. + return Stream.of(joinCondition.getLeftColumn(), joinCondition.getRightColumn()) + .map(JdbcColumnHandle::getColumnType) + .noneMatch(type -> type instanceof CharType || type instanceof VarcharType); + } + private static Optional getUnsignedMapping(JdbcTypeHandle typeHandle) { if (typeHandle.getJdbcTypeName().isEmpty()) { diff --git a/plugin/trino-memsql/src/test/java/io/trino/plugin/memsql/TestMemSqlConnectorTest.java b/plugin/trino-memsql/src/test/java/io/trino/plugin/memsql/TestMemSqlConnectorTest.java index 1d3a7c8063e0..de358453b8ce 100644 --- a/plugin/trino-memsql/src/test/java/io/trino/plugin/memsql/TestMemSqlConnectorTest.java +++ b/plugin/trino-memsql/src/test/java/io/trino/plugin/memsql/TestMemSqlConnectorTest.java @@ -20,6 +20,7 @@ import io.trino.testing.MaterializedResult; import io.trino.testing.MaterializedRow; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; import org.testng.SkipException; @@ -90,6 +91,32 @@ protected boolean supportsCommentOnColumn() return false; } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_TOPN_PUSHDOWN: + return false; + + case SUPPORTS_AGGREGATION_PUSHDOWN: + return false; + + case SUPPORTS_JOIN_PUSHDOWN: + return true; + + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + @Override protected TestTable createTableWithDefaultColumns() { diff --git a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java index 2737a0ff49cb..7dc1344ae8e0 100644 --- a/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java +++ b/plugin/trino-mysql/src/main/java/io/trino/plugin/mysql/MySqlClient.java @@ -21,8 +21,10 @@ import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.WriteMapping; import io.trino.plugin.jdbc.expression.AggregateFunctionRewriter; import io.trino.plugin.jdbc.expression.AggregateFunctionRule; @@ -41,6 +43,8 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -64,6 +68,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.google.common.base.Verify.verify; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -461,6 +466,37 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + @Override + public Optional implementJoin( + ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments) + { + if (joinType == JoinType.FULL_OUTER) { + // Not supported in MySQL + return Optional.empty(); + } + return super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments); + } + + @Override + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + if (joinCondition.getOperator() == JoinCondition.Operator.IS_DISTINCT_FROM) { + // Not supported in MySQL + return false; + } + + // Remote database can be case insensitive. + return Stream.of(joinCondition.getLeftColumn(), joinCondition.getRightColumn()) + .map(JdbcColumnHandle::getColumnType) + .noneMatch(type -> type instanceof CharType || type instanceof VarcharType); + } + private ColumnMapping jsonColumnMapping() { return ColumnMapping.sliceMapping( diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java index fb48372f0e64..dbba4f148428 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/BaseMySqlConnectorTest.java @@ -23,6 +23,7 @@ import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.MaterializedResult; import io.trino.testing.MaterializedRow; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; @@ -75,6 +76,29 @@ protected boolean supportsCommentOnColumn() return false; } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_TOPN_PUSHDOWN: + return false; + + case SUPPORTS_JOIN_PUSHDOWN: + return true; + + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + protected abstract SqlExecutor getMySqlExecutor(); @Override diff --git a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java index 7a82eae2b2d7..f19cbb980cb7 100644 --- a/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java +++ b/plugin/trino-oracle/src/main/java/io/trino/plugin/oracle/OracleClient.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DoubleWriteFunction; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongWriteFunction; @@ -29,6 +30,7 @@ import io.trino.plugin.jdbc.WriteMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.CharType; import io.trino.spi.type.Chars; @@ -337,6 +339,12 @@ else if (precision > Decimals.MAX_PRECISION || actualPrecision <= 0) { return Optional.empty(); } + @Override + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + return joinCondition.getOperator() != JoinCondition.Operator.IS_DISTINCT_FROM; + } + public static LongWriteFunction oracleDateWriteFunction() { return (statement, index, value) -> { diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java index 1e83bbff5c7a..8ef28165f7fc 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.MaterializedResult; import io.trino.testing.ResultWithQueryId; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.SqlExecutor; import io.trino.testing.sql.TestTable; import org.testng.annotations.Test; @@ -65,6 +66,30 @@ protected boolean supportsCommentOnTable() return false; } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_LIMIT_PUSHDOWN: + return false; + + case SUPPORTS_TOPN_PUSHDOWN: + return false; + + case SUPPORTS_AGGREGATION_PUSHDOWN: + return false; + + case SUPPORTS_JOIN_PUSHDOWN: + return true; + + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + @Override public void testCreateSchema() { diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index fc1c86f886d9..776c69c8cdca 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -27,6 +27,7 @@ import io.trino.plugin.jdbc.DoubleReadFunction; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongReadFunction; @@ -65,6 +66,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortItem; import io.trino.spi.connector.TableNotFoundException; @@ -109,6 +111,7 @@ import java.util.Optional; import java.util.UUID; import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; @@ -710,6 +713,32 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + @Override + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + boolean isVarchar = Stream.of(joinCondition.getLeftColumn(), joinCondition.getRightColumn()) + .map(JdbcColumnHandle::getColumnType) + .anyMatch(type -> type instanceof CharType || type instanceof VarcharType); + if (isVarchar) { + // PostgreSQL is case sensitive by default, but orders varchars differently + JoinCondition.Operator operator = joinCondition.getOperator(); + switch (operator) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + break; + case EQUAL: + case NOT_EQUAL: + case IS_DISTINCT_FROM: + return true; + } + return false; + } + + return true; + } + private static ColumnMapping charColumnMapping(int charLength) { if (charLength > CharType.MAX_LENGTH) { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 1fe848bd6b6d..63c532d7ad1b 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -19,9 +19,12 @@ import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.JdbcSqlExecutor; import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; @@ -37,6 +40,8 @@ import static io.trino.SystemSessionProperties.USE_MARK_DISTINCT; import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; +import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.testing.sql.TestTable.randomTableSuffix; import static java.lang.String.format; import static java.util.stream.Collectors.joining; @@ -96,6 +101,21 @@ protected boolean supportsCommentOnTable() return false; } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_JOIN_PUSHDOWN: + return true; + + default: + return super.hasBehavior(connectorBehavior); + } + } + @Override protected TestTable createTableWithDefaultColumns() { @@ -303,6 +323,20 @@ public void testPredicatePushdown() .matches("VALUES BIGINT '643', 898") .ordered() .isFullyPushedDown(); + + // predicate over join + Session joinPushdownEnabled = joinPushdownEnabled(getSession()); + assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.custkey = n.nationkey WHERE acctbal > 8000")) + .isFullyPushedDown(); + + // varchar predicate over join + assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.custkey = n.nationkey WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea'")) + .isFullyPushedDown(); + assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.custkey = n.nationkey WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea'")) + .isNotFullyPushedDown( + node(JoinNode.class, + anyTree(node(TableScanNode.class)), + anyTree(node(TableScanNode.class)))); } @Test @@ -490,6 +524,14 @@ public void testAggregationPushdown() "GROUP BY clerk")) .isFullyPushedDown(); + // GROUP BY with JOIN + assertThat(query(joinPushdownEnabled(getSession()), "" + + "SELECT n.regionkey, sum(c.acctbal) acctbals " + + "FROM nation n " + + "LEFT JOIN customer c USING (nationkey) " + + "GROUP BY 1")) + .isFullyPushedDown(); + // decimals String schemaName = getSession().getSchema().orElseThrow(); try (TestTable testTable = new TestTable(postgreSqlServer::execute, @@ -762,6 +804,14 @@ public void testLimitPushdown() // with TopN assertThat(query("SELECT * FROM (SELECT regionkey FROM nation ORDER BY name ASC LIMIT 10) LIMIT 5")).isFullyPushedDown(); + + // LIMIT with JOIN + assertThat(query(joinPushdownEnabled(getSession()), "" + + "SELECT n.name, r.name " + + "FROM nation n " + + "LEFT JOIN region r USING (regionkey) " + + "LIMIT 30")) + .isFullyPushedDown(); } /** diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index 3783800fb69d..872ddb98ba39 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -26,6 +26,7 @@ import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcTableHandle; @@ -46,6 +47,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.connector.JoinCondition; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -68,6 +70,7 @@ import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; +import java.util.stream.Stream; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.microsoft.sqlserver.jdbc.SQLServerConnection.TRANSACTION_SNAPSHOT; @@ -371,6 +374,20 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + @Override + protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition) + { + if (joinCondition.getOperator() == JoinCondition.Operator.IS_DISTINCT_FROM) { + // Not supported in SQL Server + return false; + } + + // Remote database can be case insensitive. + return Stream.of(joinCondition.getLeftColumn(), joinCondition.getRightColumn()) + .map(JdbcColumnHandle::getColumnType) + .noneMatch(type -> type instanceof CharType || type instanceof VarcharType); + } + @Override protected String createTableSql(RemoteTableName remoteTableName, List columns, ConnectorTableMetadata tableMetadata) { diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java index 45ea665c22dd..6b39866275ab 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.plan.MarkDistinctNode; import io.trino.sql.planner.plan.ProjectNode; import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; import io.trino.testing.sql.TestTable; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -85,6 +86,28 @@ protected boolean supportsCommentOnColumn() return false; } + @Override + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_TOPN_PUSHDOWN: + return false; + + case SUPPORTS_JOIN_PUSHDOWN: + return true; + + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + @Override protected TestTable createTableWithDefaultColumns() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 54b706bab668..d5efa823ff3d 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -25,6 +25,8 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertContains; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_JOIN_PUSHDOWN; +import static io.trino.testing.TestingConnectorBehavior.SUPPORTS_PREDICATE_PUSHDOWN; import static io.trino.testing.assertions.Assert.assertEquals; import static java.lang.String.join; import static java.util.Collections.nCopies; @@ -39,6 +41,27 @@ public abstract class BaseConnectorTest extends AbstractTestDistributedQueries { + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN); + + case SUPPORTS_JOIN_PUSHDOWN: + // Currently no connector supports Join pushdown by default. JDBC connectors may support Join pushdown and BaseJdbcConnectorTest + // verifies truthfulness of SUPPORTS_JOIN_PUSHDOWN declaration, so it is a safe default. + return false; + + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + return hasBehavior(SUPPORTS_JOIN_PUSHDOWN); + + default: + return true; + } + } + @Test public void testColumnsInReverseOrder() { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java new file mode 100644 index 000000000000..6a715467b24c --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/TestingConnectorBehavior.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +public enum TestingConnectorBehavior +{ + SUPPORTS_PREDICATE_PUSHDOWN, + SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY, + SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY, + SUPPORTS_LIMIT_PUSHDOWN, + SUPPORTS_TOPN_PUSHDOWN, + SUPPORTS_AGGREGATION_PUSHDOWN, + SUPPORTS_JOIN_PUSHDOWN, + SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN, + SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM, + /**/; +}