Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@
import io.trino.spi.type.TypeOperators;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.ConnectorExpressions;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.tree.QualifiedName;
import io.trino.transaction.TransactionManager;
Expand Down Expand Up @@ -158,6 +157,7 @@
import static io.trino.metadata.RedirectionAwareTableHandle.noRedirection;
import static io.trino.metadata.RedirectionAwareTableHandle.withRedirectionTo;
import static io.trino.metadata.SignatureBinder.applyBoundVariables;
import static io.trino.plugin.base.expression.ConnectorExpressions.extractVariables;
import static io.trino.spi.ErrorType.EXTERNAL;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
Expand Down Expand Up @@ -2040,7 +2040,7 @@ private void verifyProjection(TableHandle table, List<ConnectorExpression> proje
.map(Assignment::getVariable)
.collect(toImmutableSet());
projections.stream()
.flatMap(connectorExpression -> ConnectorExpressions.extractVariables(connectorExpression).stream())
.flatMap(connectorExpression -> extractVariables(connectorExpression).stream())
.map(Variable::getName)
.filter(variableName -> !assignedVariables.contains(variableName))
.findAny()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,10 @@ public QueryAssert isNotFullyPushedDown(PlanMatchPattern retainedSubplan)

Comment thread
findepi marked this conversation as resolved.
/**
* Verifies join query is not fully pushed down by containing JOIN node.
*
* @deprecated because the method is not tested in BaseQueryAssertionsTest yet
*/
@Deprecated
@CanIgnoreReturnValue
public QueryAssert joinIsNotFullyPushedDown()
{
Expand All @@ -580,6 +583,7 @@ public QueryAssert joinIsNotFullyPushedDown()
.whereIsInstanceOfAny(JoinNode.class)
.findFirst()
.isEmpty()) {
// TODO show then plan when assertions fails (like hasPlan()) and add negative test coverage in BaseQueryAssertionsTest
throw new IllegalStateException("Join node should be present in explain plan, when pushdown is not applied");
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,36 @@
package io.trino.plugin.base.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Streams.stream;
import static io.trino.spi.expression.Constant.TRUE;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.util.Objects.requireNonNull;

public final class ConnectorExpressions
{
private ConnectorExpressions() {}

public static List<Variable> extractVariables(ConnectorExpression expression)
{
return preOrder(expression)
.filter(Variable.class::isInstance)
.map(Variable.class::cast)
.collect(toImmutableList());
}

public static List<ConnectorExpression> extractConjuncts(ConnectorExpression expression)
{
ImmutableList.Builder<ConnectorExpression> resultBuilder = ImmutableList.builder();
Expand All @@ -38,6 +53,10 @@ public static List<ConnectorExpression> extractConjuncts(ConnectorExpression exp

private static void extractConjuncts(ConnectorExpression expression, ImmutableList.Builder<ConnectorExpression> resultBuilder)
{
if (expression.equals(TRUE)) {
// Skip useless conjuncts.
return;
}
if (expression instanceof Call call) {
if (AND_FUNCTION_NAME.equals(call.getFunctionName())) {
for (ConnectorExpression argument : call.getArguments()) {
Expand All @@ -64,4 +83,11 @@ public static ConnectorExpression and(List<ConnectorExpression> expressions)
}
return getOnlyElement(expressions);
}

private static Stream<ConnectorExpression> preOrder(ConnectorExpression expression)
{
return stream(
Traverser.forTree((SuccessorsFunction<ConnectorExpression>) ConnectorExpression::getChildren)
.depthFirstPreOrder(requireNonNull(expression, "expression is null")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,35 @@ protected static Optional<ParameterizedExpression> getAdditionalPredicate(List<P

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
try (Connection connection = this.connectionFactory.openConnection(session)) {
return Optional.of(queryBuilder.prepareJoinQuery(
this,
session,
connection,
joinType,
leftSource,
leftProjections,
rightSource,
rightProjections,
joinConditions));
}
catch (SQLException e) {
throw new TrinoException(JDBC_ERROR, e);
}
}

@Deprecated
@Override
public Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Expand All @@ -540,7 +569,7 @@ public Optional<PreparedQuery> implementJoin(
}

try (Connection connection = this.connectionFactory.openConnection(session)) {
return Optional.of(queryBuilder.prepareJoinQuery(
return Optional.of(queryBuilder.legacyPrepareJoinQuery(
this,
session,
connection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,20 @@ public CallableStatement buildProcedure(ConnectorSession session, Connection con

@Override
public Optional<PreparedQuery> implementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Map<JdbcColumnHandle, String> leftProjections,
PreparedQuery rightSource,
Map<JdbcColumnHandle, String> rightProjections,
List<ParameterizedExpression> joinConditions,
JoinStatistics statistics)
{
return delegate.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics);
}

@Override
public Optional<PreparedQuery> legacyImplementJoin(
ConnectorSession session,
JoinType joinType,
PreparedQuery leftSource,
Expand All @@ -290,7 +304,7 @@ public Optional<PreparedQuery> implementJoin(
Map<JdbcColumnHandle, String> leftAssignments,
JoinStatistics statistics)
{
return delegate.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
return delegate.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
Expand All @@ -94,6 +95,7 @@
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isAggregationPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexExpressionPushdown;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexJoinPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled;
import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert;
Expand Down Expand Up @@ -436,6 +438,120 @@ static JdbcColumnHandle createSyntheticAggregationColumn(AggregateFunction aggre
.build();
}

@Override
public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
ConnectorSession session,
JoinType joinType,
ConnectorTableHandle left,
ConnectorTableHandle right,
ConnectorExpression joinCondition,
Map<String, ColumnHandle> leftAssignments,
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics)
{
if (!isComplexJoinPushdownEnabled(session)) {
// Fallback to the old join pushdown code
return JdbcMetadata.super.applyJoin(
session,
joinType,
left,
right,
joinCondition,
leftAssignments,
rightAssignments,
statistics);
}

if (isTableHandleForProcedure(left) || isTableHandleForProcedure(right)) {
return Optional.empty();
}

if (!isJoinPushdownEnabled(session)) {
Comment thread
findepi marked this conversation as resolved.
return Optional.empty();
}

JdbcTableHandle leftHandle = flushAttributesAsQuery(session, (JdbcTableHandle) left);
JdbcTableHandle rightHandle = flushAttributesAsQuery(session, (JdbcTableHandle) right);

if (!leftHandle.getAuthorization().equals(rightHandle.getAuthorization())) {
return Optional.empty();
}
int nextSyntheticColumnId = max(leftHandle.getNextSyntheticColumnId(), rightHandle.getNextSyntheticColumnId());

ImmutableMap.Builder<JdbcColumnHandle, JdbcColumnHandle> newLeftColumnsBuilder = ImmutableMap.builder();
Comment thread
findepi marked this conversation as resolved.
OptionalInt maxColumnNameLength = jdbcClient.getMaxColumnNameLength(session);
for (JdbcColumnHandle column : jdbcClient.getColumns(session, leftHandle)) {
newLeftColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength));
nextSyntheticColumnId++;
}
Map<JdbcColumnHandle, JdbcColumnHandle> newLeftColumns = newLeftColumnsBuilder.buildOrThrow();

ImmutableMap.Builder<JdbcColumnHandle, JdbcColumnHandle> newRightColumnsBuilder = ImmutableMap.builder();
for (JdbcColumnHandle column : jdbcClient.getColumns(session, rightHandle)) {
newRightColumnsBuilder.put(column, createSyntheticJoinProjectionColumn(column, nextSyntheticColumnId, maxColumnNameLength));
nextSyntheticColumnId++;
}
Map<JdbcColumnHandle, JdbcColumnHandle> newRightColumns = newRightColumnsBuilder.buildOrThrow();

Map<String, ColumnHandle> assignments = ImmutableMap.<String, ColumnHandle>builder()
.putAll(leftAssignments.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> newLeftColumns.get((JdbcColumnHandle) entry.getValue()))))
.putAll(rightAssignments.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> newRightColumns.get((JdbcColumnHandle) entry.getValue()))))
.buildOrThrow();

ImmutableList.Builder<ParameterizedExpression> joinConditions = ImmutableList.builder();
for (ConnectorExpression conjunct : extractConjuncts(joinCondition)) {
Optional<ParameterizedExpression> converted = jdbcClient.convertPredicate(session, conjunct, assignments);
if (converted.isEmpty()) {
return Optional.empty();
}
joinConditions.add(converted.get());
}

Optional<PreparedQuery> joinQuery = jdbcClient.implementJoin(
session,
joinType,
asPreparedQuery(leftHandle),
newLeftColumns.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
asPreparedQuery(rightHandle),
newRightColumns.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
joinConditions.build(),
statistics);

if (joinQuery.isEmpty()) {
return Optional.empty();
}

return Optional.of(new JoinApplicationResult<>(
new JdbcTableHandle(
new JdbcQueryRelationHandle(joinQuery.get()),
TupleDomain.all(),
ImmutableList.of(),
Optional.empty(),
OptionalLong.empty(),
Optional.of(
ImmutableList.<JdbcColumnHandle>builder()
.addAll(newLeftColumns.values())
.addAll(newRightColumns.values())
.build()),
leftHandle.getAllReferencedTables().flatMap(leftReferencedTables ->
rightHandle.getAllReferencedTables().map(rightReferencedTables ->
ImmutableSet.<SchemaTableName>builder()
.addAll(leftReferencedTables)
.addAll(rightReferencedTables)
.build())),
nextSyntheticColumnId,
leftHandle.getAuthorization(),
leftHandle.getUpdateAssignments()),
ImmutableMap.copyOf(newLeftColumns),
ImmutableMap.copyOf(newRightColumns),
precalculateStatisticsForPushdown));
}

@Deprecated
@Override
public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
ConnectorSession session,
Expand Down Expand Up @@ -488,16 +604,16 @@ public Optional<JoinApplicationResult<ConnectorTableHandle>> applyJoin(
jdbcJoinConditions.add(new JdbcJoinCondition(leftColumn.get(), joinCondition.getOperator(), rightColumn.get()));
}

Optional<PreparedQuery> joinQuery = jdbcClient.implementJoin(
Optional<PreparedQuery> joinQuery = jdbcClient.legacyImplementJoin(
session,
joinType,
asPreparedQuery(leftHandle),
asPreparedQuery(rightHandle),
jdbcJoinConditions.build(),
newRightColumns.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())),
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
newLeftColumns.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().getColumnName())),
.collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().getColumnName())),
statistics);

if (joinQuery.isEmpty()) {
Expand Down
Loading