Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
Expand Down Expand Up @@ -494,7 +493,7 @@ Optional<JoinApplicationResult<TableHandle>> applyJoin(
JoinType joinType,
TableHandle left,
TableHandle right,
List<JoinCondition> joinConditions,
ConnectorExpression joinCondition,
Map<String, ColumnHandle> leftAssignments,
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import io.trino.spi.connector.Constraint;
import io.trino.spi.connector.ConstraintApplicationResult;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.LimitApplicationResult;
Expand Down Expand Up @@ -1619,7 +1618,7 @@ public Optional<JoinApplicationResult<TableHandle>> applyJoin(
JoinType joinType,
TableHandle left,
TableHandle right,
List<JoinCondition> joinConditions,
ConnectorExpression joinCondition,
Map<String, ColumnHandle> leftAssignments,
Map<String, ColumnHandle> rightAssignments,
JoinStatistics statistics)
Expand All @@ -1640,7 +1639,7 @@ public Optional<JoinApplicationResult<TableHandle>> applyJoin(
joinType,
left.getConnectorHandle(),
right.getConnectorHandle(),
joinConditions,
joinCondition,
leftAssignments,
rightAssignments,
statistics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.Session;
import io.trino.metadata.LiteralFunction;
import io.trino.metadata.ResolvedFunction;
import io.trino.plugin.base.expression.ConnectorExpressions;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
Expand Down Expand Up @@ -71,6 +72,7 @@
import io.trino.type.Re2JRegexp;
import io.trino.type.Re2JRegexpType;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -105,6 +107,8 @@
import static io.trino.spi.expression.StandardFunctions.SUBTRACT_FUNCTION_NAME;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.sql.ExpressionUtils.combineConjuncts;
import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature;
Expand All @@ -125,12 +129,42 @@ public static Expression translate(Session session, ConnectorExpression expressi
.orElseThrow(() -> new UnsupportedOperationException("Expression is not supported: " + expression.toString()));
}

public static Optional<ConnectorExpression> translate(Session session, Expression expression, TypeAnalyzer types, TypeProvider inputTypes, PlannerContext plannerContext)
Comment thread
findepi marked this conversation as resolved.
Outdated
public static Optional<ConnectorExpression> translate(Session session, Expression expression, TypeProvider types, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
{
return new SqlToConnectorExpressionTranslator(session, types.getTypes(session, inputTypes, expression), plannerContext)
return new SqlToConnectorExpressionTranslator(session, typeAnalyzer.getTypes(session, types, expression), plannerContext)
.process(expression);
}

public static ConnectorExpressionTranslation translateConjuncts(
Session session,
Expression expression,
TypeProvider types,
PlannerContext plannerContext,
TypeAnalyzer typeAnalyzer)
{
Map<NodeRef<Expression>, Type> remainingExpressionTypes = typeAnalyzer.getTypes(session, types, expression);
ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator translator = new ConnectorExpressionTranslator.SqlToConnectorExpressionTranslator(
session,
remainingExpressionTypes,
plannerContext);

List<Expression> conjuncts = extractConjuncts(expression);
List<Expression> remaining = new ArrayList<>();
List<ConnectorExpression> converted = new ArrayList<>(conjuncts.size());
for (Expression conjunct : conjuncts) {
Optional<ConnectorExpression> connectorExpression = translator.process(conjunct);
if (connectorExpression.isPresent()) {
converted.add(connectorExpression.get());
}
else {
remaining.add(conjunct);
}
}
return new ConnectorExpressionTranslation(
ConnectorExpressions.and(converted),
combineConjuncts(plannerContext.getMetadata(), remaining));
}

@VisibleForTesting
static FunctionName functionNameForComparisonOperator(ComparisonExpression.Operator operator)
{
Expand All @@ -157,6 +191,15 @@ static FunctionName functionNameForArithmeticBinaryOperator(ArithmeticBinaryExpr
};
}

public record ConnectorExpressionTranslation(ConnectorExpression connectorExpression, Expression remainingExpression)
{
public ConnectorExpressionTranslation
{
requireNonNull(connectorExpression, "connectorExpression is null");
requireNonNull(remainingExpression, "remainingExpression is null");
}
}

private static class ConnectorToSqlExpressionTranslator
{
private final Session session;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ public PlanOptimizers(
.addAll(pushIntoTableScanRulesExceptJoins)
// PushJoinIntoTableScan must run after ReorderJoins (and DetermineJoinDistributionType)
// otherwise too early pushdown could prevent optimal plan from being selected.
.add(new PushJoinIntoTableScan(metadata))
.add(new PushJoinIntoTableScan(plannerContext, typeAnalyzer))
// DetermineTableScanNodePartitioning is needed to needs to ensure all table handles have proper partitioning determined
// Must run before AddExchanges
.add(new DetermineTableScanNodePartitioning(metadata, nodePartitioningManager, taskCountEstimator))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BasicRelationStatistics;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.ConnectorExpressionTranslator.ConnectorExpressionTranslation;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
Expand All @@ -43,14 +42,11 @@
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand All @@ -59,7 +55,6 @@
import static io.trino.matching.Capture.newCapture;
import static io.trino.spi.predicate.Domain.onlyNull;
import static io.trino.sql.ExpressionUtils.and;
import static io.trino.sql.ExpressionUtils.extractConjuncts;
import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown;
import static io.trino.sql.planner.plan.JoinNode.Type.FULL;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
Expand All @@ -81,11 +76,13 @@ public class PushJoinIntoTableScan
.with(left().matching(tableScan().capturedAs(LEFT_TABLE_SCAN)))
.with(right().matching(tableScan().capturedAs(RIGHT_TABLE_SCAN)));

private final Metadata metadata;
private final PlannerContext plannerContext;
private final TypeAnalyzer typeAnalyzer;

public PushJoinIntoTableScan(Metadata metadata)
public PushJoinIntoTableScan(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
{
this.metadata = requireNonNull(metadata, "metadata is null");
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
}

@Override
Expand Down Expand Up @@ -113,9 +110,14 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");

Expression effectiveFilter = getEffectiveFilter(joinNode);
FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context);
ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts(
context.getSession(),
effectiveFilter,
context.getSymbolAllocator().getTypes(),
plannerContext,
typeAnalyzer);

if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) {
// TODO add extra filter node above join
return Result.empty();
}
Expand Down Expand Up @@ -144,13 +146,13 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
*/
JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context);

Optional<JoinApplicationResult<TableHandle>> joinApplicationResult = metadata.applyJoin(
Optional<JoinApplicationResult<TableHandle>> joinApplicationResult = plannerContext.getMetadata().applyJoin(
context.getSession(),
getJoinType(joinNode),
left.getTable(),
right.getTable(),
filterSplitResult.getPushableConditions(),
// TODO we could pass only subset of assignments here, those which are needed to resolve filterSplitResult.getPushableConditions
translation.connectorExpression(),
// TODO we could pass only subset of assignments here, those which are needed to resolve translation.getPushableConditions
leftAssignments,
rightAssignments,
joinStatistics);
Expand Down Expand Up @@ -254,88 +256,6 @@ public Expression getEffectiveFilter(JoinNode node)
return effectiveFilter;
}

private FilterSplitResult splitFilter(Expression filter, List<Symbol> leftSymbolsList, List<Symbol> rightSymbolsList, Context context)
{
Set<Symbol> leftSymbols = ImmutableSet.copyOf(leftSymbolsList);
Set<Symbol> rightSymbols = ImmutableSet.copyOf(rightSymbolsList);

ImmutableList.Builder<JoinCondition> comparisonConditions = ImmutableList.builder();
ImmutableList.Builder<Expression> remainingConjuncts = ImmutableList.builder();

for (Expression conjunct : extractConjuncts(filter)) {
getPushableJoinCondition(conjunct, leftSymbols, rightSymbols, context)
.ifPresentOrElse(comparisonConditions::add, () -> remainingConjuncts.add(conjunct));
}

return new FilterSplitResult(comparisonConditions.build(), ExpressionUtils.and(remainingConjuncts.build()));
}

private Optional<JoinCondition> getPushableJoinCondition(Expression conjunct, Set<Symbol> leftSymbols, Set<Symbol> rightSymbols, Context context)
{
if (!(conjunct instanceof ComparisonExpression)) {
return Optional.empty();
}
ComparisonExpression comparison = (ComparisonExpression) conjunct;

if (!(comparison.getLeft() instanceof SymbolReference) || !(comparison.getRight() instanceof SymbolReference)) {
return Optional.empty();
}
Symbol left = Symbol.from(comparison.getLeft());
Symbol right = Symbol.from(comparison.getRight());
ComparisonExpression.Operator operator = comparison.getOperator();

if (!leftSymbols.contains(left)) {
// lets try with flipped expression
Symbol tmp = left;
left = right;
right = tmp;
operator = operator.flip();
}

if (leftSymbols.contains(left) && rightSymbols.contains(right)) {
return Optional.of(new JoinCondition(
joinConditionOperator(operator),
new Variable(left.getName(), context.getSymbolAllocator().getTypes().get(left)),
new Variable(right.getName(), context.getSymbolAllocator().getTypes().get(right))));
}
return Optional.empty();
}

private static class FilterSplitResult
{
private final List<JoinCondition> pushableConditions;
private final Expression remainingFilter;

public FilterSplitResult(List<JoinCondition> pushableConditions, Expression remainingFilter)
{
this.pushableConditions = requireNonNull(pushableConditions, "pushableConditions is null");
this.remainingFilter = requireNonNull(remainingFilter, "remainingFilter is null");
}

public List<JoinCondition> getPushableConditions()
{
return pushableConditions;
}

public Expression getRemainingFilter()
{
return remainingFilter;
}
}

private JoinCondition.Operator joinConditionOperator(ComparisonExpression.Operator operator)
{
return switch (operator) {
case EQUAL -> JoinCondition.Operator.EQUAL;
case NOT_EQUAL -> JoinCondition.Operator.NOT_EQUAL;
case LESS_THAN -> JoinCondition.Operator.LESS_THAN;
case LESS_THAN_OR_EQUAL -> JoinCondition.Operator.LESS_THAN_OR_EQUAL;
case GREATER_THAN -> JoinCondition.Operator.GREATER_THAN;
case GREATER_THAN_OR_EQUAL -> JoinCondition.Operator.GREATER_THAN_OR_EQUAL;
case IS_DISTINCT_FROM -> JoinCondition.Operator.IS_DISTINCT_FROM;
};
}

private JoinType getJoinType(JoinNode joinNode)
{
return switch (joinNode.getType()) {
Expand Down
Loading