Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
Expand Down Expand Up @@ -92,6 +94,12 @@ public Pattern<JoinNode> getPattern()
return PATTERN;
}

@Override
public boolean isEnabled(Session session)
{
return isAllowPushdownIntoConnectors(session);
}

@Override
public Result apply(JoinNode joinNode, Captures captures, Context context)
{
Expand All @@ -105,7 +113,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
verify(!left.isForDelete() && !right.isForDelete(), "Unexpected Join over for-delete table scan");

Expression effectiveFilter = getEffectiveFilter(joinNode);
FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), context);
FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context);

if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
// TODO add extra filter node above join
Expand Down Expand Up @@ -156,13 +164,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
Map<ColumnHandle, ColumnHandle> leftColumnHandlesMapping = joinApplicationResult.get().getLeftColumnHandles();
Map<ColumnHandle, ColumnHandle> rightColumnHandlesMapping = joinApplicationResult.get().getRightColumnHandles();

ImmutableMap.Builder<Symbol, ColumnHandle> newAssignments = ImmutableMap.builder();
newAssignments.putAll(left.getAssignments().entrySet().stream().collect(toImmutableMap(
ImmutableMap.Builder<Symbol, ColumnHandle> assignmentsBuilder = ImmutableMap.builder();
assignmentsBuilder.putAll(left.getAssignments().entrySet().stream().collect(toImmutableMap(
Map.Entry::getKey,
entry -> leftColumnHandlesMapping.get(entry.getValue()))));
newAssignments.putAll(right.getAssignments().entrySet().stream().collect(toImmutableMap(
assignmentsBuilder.putAll(right.getAssignments().entrySet().stream().collect(toImmutableMap(
Map.Entry::getKey,
entry -> rightColumnHandlesMapping.get(entry.getValue()))));
Map<Symbol, ColumnHandle> assignments = assignmentsBuilder.build();

// convert enforced constraint
JoinNode.Type joinType = joinNode.getType();
Expand All @@ -176,7 +185,17 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
.putAll(rightConstraint.getDomains().orElseThrow())
.build());

return Result.ofPlanNode(new TableScanNode(joinNode.getId(), handle, joinNode.getOutputSymbols(), newAssignments.build(), newEnforcedConstraint, false));
return Result.ofPlanNode(
new ProjectNode(
context.getIdAllocator().getNextId(),
new TableScanNode(
joinNode.getId(),
handle,
ImmutableList.copyOf(assignments.keySet()),
assignments,
newEnforcedConstraint,
false),
Assignments.identity(joinNode.getOutputSymbols())));
}

private JoinStatistics getJoinStatistics(JoinNode join, TableScanNode left, TableScanNode right, Context context)
Expand Down Expand Up @@ -228,21 +247,6 @@ private TupleDomain<ColumnHandle> deriveConstraint(TupleDomain<ColumnHandle> sou
});
}

@Override
public boolean isEnabled(Session session)
{
return isAllowPushdownIntoConnectors(session);
}

private TupleDomain<ColumnHandle> transformToNewAssignments(TupleDomain<ColumnHandle> tupleDomain, Map<ColumnHandle, ColumnHandle> newAssignments)
{
return tupleDomain.transform(handle -> {
ColumnHandle newHandle = newAssignments.get(handle);
checkArgument(newHandle != null, "Mapping not found for handle %s", handle);
return newHandle;
});
}

public Expression getEffectiveFilter(JoinNode node)
{
Expression effectiveFilter = and(node.getCriteria().stream().map(JoinNode.EquiJoinClause::toExpression).collect(toImmutableList()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import static io.trino.spi.predicate.Domain.onlyNull;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester;
import static io.trino.sql.planner.plan.JoinNode.Type.FULL;
Expand Down Expand Up @@ -174,7 +175,8 @@ public void testPushJoinIntoTableScan(JoinNode.Type joinType, Optional<Compariso
})
.withSession(MOCK_SESSION)
.matches(
tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName()));
project(
tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName())));
}
}

Expand Down Expand Up @@ -423,10 +425,11 @@ public void testPushJoinIntoTableScanPreservesEnforcedConstraint(JoinNode.Type j
})
.withSession(MOCK_SESSION)
.matches(
tableScan(
tableHandle -> JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.equals(((MockConnectorTableHandle) tableHandle).getTableName()),
expectedConstraint,
ImmutableMap.of()));
project(
tableScan(
tableHandle -> JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.equals(((MockConnectorTableHandle) tableHandle).getTableName()),
expectedConstraint,
ImmutableMap.of())));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,18 +385,29 @@ public QueryAssert isFullyPushedDown()
public final QueryAssert isNotFullyPushedDown(Class<? extends PlanNode>... retainedNodes)
{
checkArgument(retainedNodes.length > 0, "No retainedNodes");
PlanMatchPattern expectedPlan = PlanMatchPattern.node(TableScanNode.class);
for (Class<? extends PlanNode> retainedNode : ImmutableList.copyOf(retainedNodes).reverse()) {
expectedPlan = PlanMatchPattern.node(retainedNode, expectedPlan);
}
return isNotFullyPushedDown(expectedPlan);
}

/**
* Verifies query is not fully pushed down and verifies the results are the same as when the pushdown is fully disabled.
* <p>
* <b>Note:</b> the primary intent of this assertion is to ensure the test is updated to {@link #isFullyPushedDown()}
* when pushdown capabilities are improved.
*/
public final QueryAssert isNotFullyPushedDown(PlanMatchPattern retainedSubplan)
{
PlanMatchPattern expectedPlan = PlanMatchPattern.anyTree(retainedSubplan);

// Compare the results with pushdown disabled, so that explicit matches() call is not needed
verifyResultsWithPushdownDisabled();

transaction(runner.getTransactionManager(), runner.getAccessControl())
.execute(session, session -> {
Plan plan = runner.createPlan(session, query, WarningCollector.NOOP);
PlanMatchPattern expectedPlan = PlanMatchPattern.node(TableScanNode.class);
for (Class<? extends PlanNode> retainedNode : ImmutableList.copyOf(retainedNodes).reverse()) {
expectedPlan = PlanMatchPattern.node(retainedNode, expectedPlan);
}
expectedPlan = PlanMatchPattern.anyTree(expectedPlan);
assertPlan(
session,
runner.getMetadata(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ public JoinCondition(Operator operator, ConnectorExpression leftExpression, Conn
this.rightExpression = requireNonNull(rightExpression, "rightExpression is null");
}

public Operator getOperator()
{
return operator;
}

public ConnectorExpression getLeftExpression()
{
return leftExpression;
}

public ConnectorExpression getRightExpression()
{
return rightExpression;
}

@Override
public boolean equals(Object o)
{
Expand Down
13 changes: 13 additions & 0 deletions plugin/trino-base-jdbc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-main</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-spi</artifactId>
Expand Down Expand Up @@ -185,6 +192,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.jetbrains</groupId>
<artifactId>annotations</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.FixedSplitSource;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SortItem;
import io.trino.spi.connector.TableNotFoundException;
Expand Down Expand Up @@ -445,6 +446,38 @@ public PreparedQuery prepareQuery(
}
}

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments)
{
for (JdbcJoinCondition joinCondition : joinConditions) {
if (!isSupportedJoinCondition(joinCondition)) {
return Optional.empty();
}
}

QueryBuilder queryBuilder = new QueryBuilder(this);
return Optional.of(queryBuilder.prepareJoinQuery(
session,
joinType,
leftSource,
rightSource,
joinConditions,
leftAssignments,
rightAssignments));
}

protected boolean isSupportedJoinCondition(JdbcJoinCondition joinCondition)
{
return false;
}

@Override
public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SortItem;
import io.trino.spi.connector.SystemTable;
Expand Down Expand Up @@ -205,6 +206,19 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio
return delegate.buildSql(session, connection, split, table, columns);
}

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments)
{
return delegate.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments);
}

@Override
public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<SortItem> sortOrder)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SortItem;
import io.trino.spi.connector.SystemTable;
Expand Down Expand Up @@ -154,6 +155,19 @@ public PreparedStatement buildSql(ConnectorSession session, Connection connectio
return delegate().buildSql(session, connection, split, tableHandle, columnHandles);
}

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments)
{
return delegate().implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments);
}

@Override
public JdbcOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SortItem;
import io.trino.spi.connector.SystemTable;
Expand Down Expand Up @@ -93,6 +94,15 @@ PreparedQuery prepareQuery(
PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, JdbcTableHandle table, List<JdbcColumnHandle> columns)
throws SQLException;

Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
PreparedQuery rightSource,
List<JdbcJoinCondition> joinConditions,
Map<JdbcColumnHandle, String> rightAssignments,
Map<JdbcColumnHandle, String> leftAssignments);

boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List<SortItem> sortOrder);

boolean isTopNLimitGuaranteed(ConnectorSession session);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.spi.connector.JoinCondition.Operator;

import static java.util.Objects.requireNonNull;

public class JdbcJoinCondition
{
private final JdbcColumnHandle leftColumn;
private final Operator operator;
private final JdbcColumnHandle rightColumn;

public JdbcJoinCondition(JdbcColumnHandle leftColumn, Operator operator, JdbcColumnHandle rightColumn)
{
this.leftColumn = requireNonNull(leftColumn, "leftColumn is null");
this.operator = requireNonNull(operator, "operator is null");
this.rightColumn = requireNonNull(rightColumn, "rightColumn is null");
}

public JdbcColumnHandle getLeftColumn()
{
return leftColumn;
}

public Operator getOperator()
{
return operator;
}

public JdbcColumnHandle getRightColumn()
{
return rightColumn;
}
}
Loading