Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange;
import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -566,9 +565,8 @@ private PartitioningVariableAssignments assignPartitioningVariables(Partitioning
VariableReferenceExpression variable;
if (argumentBinding.isConstant()) {
ConstantExpression constant = argumentBinding.getConstant();
RowExpression expression = constant(constant.getValue(), constant.getType());
variable = symbolAllocator.newVariable("constant_partition", constant.getType());
constants.put(variable, expression);
constants.put(variable, constant);
}
else {
variable = argumentBinding.getVariableReference();
Expand Down Expand Up @@ -658,7 +656,7 @@ private TableFinishNode createTemporaryTableWrite(
.map(source -> {
Assignments.Builder assignments = Assignments.builder();
source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getName(), variable.getType())));
constantVariables.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol)));
constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable)));
return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
})
.collect(toImmutableList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.OrderingScheme;
import com.facebook.presto.sql.planner.SymbolsExtractor;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -39,7 +43,6 @@
import java.util.stream.Stream;

import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUniqueVariable;
import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
import static com.facebook.presto.sql.planner.iterative.rule.Util.transpose;
import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.dependsOn;
Expand Down Expand Up @@ -155,7 +158,7 @@ protected static Optional<WindowNode> pullWindowNodeAboveProjects(
.putAll(identitiesAsSymbolReferences(targetInputs))
.build();

if (!newTargetChildOutputs.containsAll(extractUniqueVariable(newAssignments.getExpressions(), context.getSymbolAllocator().getTypes()))) {
if (!newTargetChildOutputs.containsAll(extractUniqueVariable(newAssignments, context.getSymbolAllocator().getTypes()))) {
// Projection uses an output of the target -- can't move the target above this projection.
return Optional.empty();
}
Expand All @@ -173,6 +176,12 @@ protected static Optional<WindowNode> pullWindowNodeAboveProjects(
}
}

private static Set<VariableReferenceExpression> extractUniqueVariable(Assignments assignments, TypeProvider types)
{
Collection<RowExpression> expressions = assignments.getExpressions();
return SymbolsExtractor.extractUniqueVariable(expressions.stream().map(OriginalExpressionUtils::castToExpression).collect(toImmutableList()), types);
}

public static class MergeAdjacentWindowsOverProjects
extends ManipulateAdjacentWindowsOverProjects
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,11 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
PlanNode output = node;

// Create identity projections for all existing symbols
Assignments.Builder leftProjections = Assignments.builder();
leftProjections.putAll(identityAssignmentsAsSymbolReferences(node.getLeft()
.getOutputVariables()));
Assignments.Builder leftProjections = Assignments.builder()
.putAll(identityAssignmentsAsSymbolReferences(node.getLeft().getOutputVariables()));

Assignments.Builder rightProjections = Assignments.builder();
rightProjections.putAll(identityAssignmentsAsSymbolReferences(node.getRight()
.getOutputVariables()));
Assignments.Builder rightProjections = Assignments.builder()
.putAll(identityAssignmentsAsSymbolReferences(node.getRight().getOutputVariables()));

// Create new projections for the new join clauses
List<JoinNode.EquiJoinClause> equiJoinClauses = new ArrayList<>();
Expand Down Expand Up @@ -626,13 +624,11 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Expression
rightSource != node.getRight() ||
!areExpressionsEquivalent(newJoinPredicate, joinPredicate)) {
// Create identity projections for all existing symbols
Assignments.Builder leftProjections = Assignments.builder();
leftProjections.putAll(identityAssignmentsAsSymbolReferences(node.getLeft()
.getOutputVariables()));
Assignments.Builder leftProjections = Assignments.builder()
.putAll(identityAssignmentsAsSymbolReferences(node.getLeft().getOutputVariables()));

Assignments.Builder rightProjections = Assignments.builder();
rightProjections.putAll(identityAssignmentsAsSymbolReferences(node.getRight()
.getOutputVariables()));
Assignments.Builder rightProjections = Assignments.builder()
.putAll(identityAssignmentsAsSymbolReferences(node.getRight().getOutputVariables()));

leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,13 @@ public Void visitProject(ProjectNode node, Set<VariableReferenceExpression> boun

Set<VariableReferenceExpression> inputs = createInputs(source, boundVariables);
for (RowExpression expression : node.getAssignments().getExpressions()) {
Set<VariableReferenceExpression> dependencies = SymbolsExtractor.extractUniqueVariable(expression, types);
Set<VariableReferenceExpression> dependencies;
if (isExpression(expression)) {
dependencies = SymbolsExtractor.extractUniqueVariable(castToExpression(expression), types);
}
else {
dependencies = SymbolsExtractor.extractUniqueVariable(expression);
}
checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs);
}

Expand Down Expand Up @@ -675,7 +681,13 @@ public Void visitApply(ApplyNode node, Set<VariableReferenceExpression> boundVar
.build();

for (RowExpression expression : node.getSubqueryAssignments().getExpressions()) {
Set<VariableReferenceExpression> dependencies = SymbolsExtractor.extractUniqueVariable(expression, types);
Set<VariableReferenceExpression> dependencies;
if (isExpression(expression)) {
dependencies = SymbolsExtractor.extractUniqueVariable(castToExpression(expression), types);
}
else {
dependencies = SymbolsExtractor.extractUniqueVariable(expression);
}
checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs);
}

Expand Down Expand Up @@ -715,6 +727,9 @@ private static ImmutableSet<VariableReferenceExpression> createInputs(PlanNode s

private void checkDependencies(Collection<VariableReferenceExpression> inputs, Collection<VariableReferenceExpression> required, String message, Object... parameters)
{
// If a variable can be assigned into another type directly, CAST is usually implicitly removed.
// For example, we can assign input VARCHAR(3) to output VARCHAR(5)
// the reference variable in the assignment will have type VARCHAR(5) while the input is VARCHAR(3).
for (VariableReferenceExpression target : required) {
checkArgument(
inputs.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanFragmenter;
import com.facebook.presto.sql.planner.PlanOptimizers;
import com.facebook.presto.sql.planner.RuleStatsRecorder;
import com.facebook.presto.sql.planner.SubPlan;
import com.facebook.presto.sql.planner.iterative.IterativeOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.TranslateExpressions;
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
import com.facebook.presto.sql.planner.sanity.PlanSanityChecker;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
Expand Down Expand Up @@ -911,4 +914,16 @@ private static List<TableScanNode> findTableScanNodes(PlanNode node)
.where(TableScanNode.class::isInstance)
.findAll();
}

public PlanOptimizer translateExpressions()
{
// Translate all OriginalExpression in planNodes to RowExpression so that we can do plan pattern asserting and printing on RowExpression only.
return new IterativeOptimizer(
new RuleStatsRecorder(),
getStatsCalculator(),
getCostCalculator(),
new ImmutableSet.Builder()
.addAll(new TranslateExpressions(getMetadata(), getSqlParser()).rules())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs;
import com.facebook.presto.sql.planner.optimizations.TranslateExpressions;
import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.tpch.TpchConnectorFactory;
Expand Down Expand Up @@ -139,7 +138,7 @@ protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPatte
sql,
ImmutableList.<PlanOptimizer>builder()
.addAll(optimizers)
.add(translateExpressions()).build(),
.add(queryRunner.translateExpressions()).build(), // To avoid assert plan failure not printing out plan (#12885)
stage,
WarningCollector.NOOP);
PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), actualPlan, pattern);
Expand Down Expand Up @@ -167,22 +166,11 @@ protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMat
queryRunner.getStatsCalculator(),
queryRunner.getCostCalculator(),
ImmutableSet.of(new RemoveRedundantIdentityProjections())),
translateExpressions());
queryRunner.translateExpressions()); // To avoid assert plan failure not printing out plan (#12885)

assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers);
}

private PlanOptimizer translateExpressions()
{
return new IterativeOptimizer(
new RuleStatsRecorder(),
queryRunner.getStatsCalculator(),
queryRunner.getCostCalculator(),
new ImmutableSet.Builder()
.addAll(new TranslateExpressions(queryRunner.getMetadata(), queryRunner.getSqlParser()).rules())
.build());
}

protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean forceSingleNode, PlanMatchPattern pattern)
{
queryRunner.inTransaction(session, transactionSession -> {
Expand Down