From dabc453721e4ceb15304b3bf6a06c96614a0929c Mon Sep 17 00:00:00 2001 From: James Sun Date: Sun, 21 Apr 2019 23:35:54 -0700 Subject: [PATCH 1/6] Move helper Assignments helper functions to AssignmentsUtils --- .../presto/sql/planner/AssignmentsUtils.java | 180 ++++++++++++++++++ .../sql/planner/LocalExecutionPlanner.java | 2 +- .../presto/sql/planner/LogicalPlanner.java | 3 +- .../presto/sql/planner/PlanBuilder.java | 3 +- .../presto/sql/planner/PlanFragmenter.java | 3 +- .../presto/sql/planner/QueryPlanner.java | 10 +- .../presto/sql/planner/RelationPlanner.java | 11 +- .../presto/sql/planner/SubqueryPlanner.java | 12 +- .../iterative/rule/EliminateCrossJoins.java | 4 +- .../rule/ExpressionRewriteRuleSet.java | 5 +- .../iterative/rule/ExtractSpatialJoins.java | 6 +- .../iterative/rule/GatherAndMergeWindows.java | 5 +- .../rule/ImplementFilteredAggregations.java | 4 +- .../iterative/rule/InlineProjections.java | 9 +- .../iterative/rule/PruneProjectColumns.java | 3 +- .../rule/PushAggregationThroughOuterJoin.java | 4 +- ...PushPartialAggregationThroughExchange.java | 4 +- .../rule/PushProjectionThroughExchange.java | 4 +- .../rule/PushProjectionThroughUnion.java | 4 +- ...RewriteSpatialPartitioningAggregation.java | 4 +- .../TransformCorrelatedInPredicateToJoin.java | 7 +- .../TransformCorrelatedScalarSubquery.java | 4 +- ...mCorrelatedSingleRowSubqueryToProject.java | 3 +- .../TransformExistsApplyToLateralNode.java | 8 +- .../sql/planner/iterative/rule/Util.java | 4 +- .../HashGenerationOptimizer.java | 6 +- .../ImplementIntersectAndExceptAsUnion.java | 6 +- .../optimizations/IndexJoinOptimizer.java | 6 +- .../OptimizeMixedDistinctAggregations.java | 6 +- .../optimizations/PlanNodeDecorrelator.java | 3 +- .../optimizations/PredicatePushDown.java | 12 +- .../PruneUnreferencedOutputs.java | 6 +- .../ScalarAggregationToJoinRewriter.java | 7 +- ...uantifiedComparisonApplyToLateralJoin.java | 5 +- .../UnaliasSymbolReferences.java | 3 +- .../presto/sql/planner/plan/Assignments.java | 148 -------------- .../presto/cost/TestCostCalculator.java | 4 +- ...tSimpleFilterProjectSemiJoinStatsRule.java | 6 +- .../TestEffectivePredicateExtractor.java | 3 +- .../presto/sql/planner/TestTypeValidator.java | 6 +- .../assertions/PlanMatchingVisitor.java | 4 +- .../iterative/TestIterativeOptimizer.java | 4 +- .../sql/planner/iterative/TestRuleIndex.java | 4 +- .../rule/TestAddIntermediateAggregations.java | 4 +- .../rule/TestEliminateCrossJoins.java | 4 +- .../rule/TestExpressionRewriteRuleSet.java | 10 +- .../iterative/rule/TestInlineProjections.java | 14 +- .../rule/TestMergeAdjacentWindows.java | 6 +- .../rule/TestPruneAggregationColumns.java | 4 +- .../TestPruneCountAggregationOverScalar.java | 4 +- .../rule/TestPruneCrossJoinColumns.java | 4 +- .../rule/TestPruneFilterColumns.java | 4 +- .../rule/TestPruneIndexSourceColumns.java | 4 +- .../iterative/rule/TestPruneJoinColumns.java | 6 +- .../iterative/rule/TestPruneLimitColumns.java | 4 +- .../rule/TestPruneMarkDistinctColumns.java | 10 +- .../rule/TestPruneProjectColumns.java | 10 +- .../rule/TestPruneSemiJoinColumns.java | 4 +- .../rule/TestPruneTableScanColumns.java | 6 +- .../iterative/rule/TestPruneTopNColumns.java | 4 +- .../rule/TestPruneValuesColumns.java | 6 +- .../rule/TestPruneWindowColumns.java | 4 +- .../TestPushAggregationThroughOuterJoin.java | 4 +- .../rule/TestPushLimitThroughProject.java | 6 +- .../TestPushProjectionThroughExchange.java | 12 +- .../rule/TestPushProjectionThroughUnion.java | 6 +- ...estRemoveUnreferencedScalarApplyNodes.java | 6 +- ...formCorrelatedScalarAggregationToJoin.java | 4 +- ...TestTransformCorrelatedScalarSubquery.java | 8 +- ...mCorrelatedSingleRowSubqueryToProject.java | 4 +- ...TestTransformExistsApplyToLateralJoin.java | 8 +- ...rrelatedInPredicateSubqueryToSemiJoin.java | 8 +- .../iterative/rule/test/TestRuleTester.java | 4 +- .../sql/planner/plan/TestAssingments.java | 3 +- .../sanity/TestVerifyOnlyOneOutputNode.java | 6 +- 75 files changed, 387 insertions(+), 349 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java new file mode 100644 index 0000000000000..97d4deb4b8cdd --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java @@ -0,0 +1,180 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.Maps; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collector; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Arrays.asList; + +/** + * Everything in this should be moved back to Assignments + */ +public class AssignmentsUtils +{ + private AssignmentsUtils() {} + + // Originally, the following functions are also static + public static Builder builder() + { + return new Builder(); + } + + public static Assignments identity(Symbol... symbols) + { + return identity(asList(symbols)); + } + + public static Assignments identity(Iterable symbols) + { + return builder().putIdentities(symbols).build(); + } + + public static Assignments copyOf(Map assignments) + { + return builder() + .putAll(assignments) + .build(); + } + + public static Assignments of() + { + return builder().build(); + } + + public static Assignments of(Symbol symbol, Expression expression) + { + return builder().put(symbol, expression).build(); + } + + public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2) + { + return builder().put(symbol1, expression1).put(symbol2, expression2).build(); + } + + // Originally, the following functions are not static move assignments as member variables + public static Assignments rewrite(Assignments assignments, ExpressionRewriter rewriter) + { + return rewrite(assignments, expression -> ExpressionTreeRewriter.rewriteWith(rewriter, expression)); + } + + public static Assignments rewrite(Assignments assignments, Function rewrite) + { + return assignments.entrySet().stream() + .map(entry -> Maps.immutableEntry(entry.getKey(), rewrite.apply(entry.getValue()))) + .collect(toAssignments()); + } + + public static Assignments filter(Assignments assignments, Collection symbols) + { + return filter(assignments, symbols::contains); + } + + public static Assignments filter(Assignments assignments, Predicate predicate) + { + return assignments.entrySet().stream() + .filter(entry -> predicate.test(entry.getKey())) + .collect(toAssignments()); + } + + public static boolean isIdentity(Assignments assignments, Symbol output) + { + Expression expression = assignments.get(output); + + return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); + } + + private static Collector, Builder, Assignments> toAssignments() + { + return Collector.of( + AssignmentsUtils::builder, + (builder, entry) -> builder.put(entry.getKey(), entry.getValue()), + (left, right) -> { + left.putAll(right.build()); + return left; + }, + Builder::build); + } + + // Originally, the following class is also static + public static class Builder + { + private final Map assignments = new LinkedHashMap<>(); + + public Builder putAll(Assignments assignments) + { + return putAll(assignments.getMap()); + } + + public Builder putAll(Map assignments) + { + for (Map.Entry assignment : assignments.entrySet()) { + put(assignment.getKey(), assignment.getValue()); + } + return this; + } + + public Builder put(Symbol symbol, Expression expression) + { + if (assignments.containsKey(symbol)) { + Expression assignment = assignments.get(symbol); + checkState( + assignment.equals(expression), + "Symbol %s already has assignment %s, while adding %s", + symbol, + assignment, + expression); + } + assignments.put(symbol, expression); + return this; + } + + public Builder put(Map.Entry assignment) + { + put(assignment.getKey(), assignment.getValue()); + return this; + } + + public Builder putIdentities(Iterable symbols) + { + for (Symbol symbol : symbols) { + putIdentity(symbol); + } + return this; + } + + public Builder putIdentity(Symbol symbol) + { + put(symbol, symbol.toSymbolReference()); + return this; + } + + public Assignments build() + { + return new Assignments(assignments); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index b1c5f4dd93f6a..c0384e1374265 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -1143,7 +1143,7 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext RowExpression filterExpression = node.getPredicate(); List outputSymbols = node.getOutputSymbols(); - return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputSymbols), outputSymbols); + return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), AssignmentsUtils.identity(outputSymbols), outputSymbols); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 386ac75416988..505a3a7d42a90 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -44,7 +44,6 @@ import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -346,7 +345,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) RelationPlan plan = createRelationPlan(analysis, insertStatement.getQuery()); Map columns = metadata.getColumnHandles(session, insert.getTarget()); - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); for (ColumnMetadata column : tableMetadata.getColumns()) { if (column.isHidden()) { continue; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index f4b476c9578ca..48b9a7933f6f3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.sql.analyzer.Analysis; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -93,7 +92,7 @@ public PlanBuilder appendProjections(Iterable expressions, SymbolAll { TranslationMap translations = copyTranslations(); - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); // add an identity projection for underlying plan for (Symbol symbol : getRoot().getOutputSymbols()) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index 3194f1ab0ad35..bd099bba56411 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -45,7 +45,6 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -593,7 +592,7 @@ private TableFinishNode createTemporaryTableWrite( // update sources sources = sources.stream() .map(source -> { - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); assignments.putIdentities(source.getOutputSymbols()); constantSymbols.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol))); return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 35f56a0c6bd8b..ad457368d571f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -332,7 +332,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression { TranslationMap outputTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); for (Expression expression : expressions) { if (expression instanceof SymbolReference) { Symbol symbol = Symbol.from(expression); @@ -379,7 +379,7 @@ private Map coerce(Iterable expression private PlanBuilder explicitCoercionFields(PlanBuilder subPlan, Iterable alreadyCoerced, Iterable uncoerced) { TranslationMap translations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap); - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); projections.putAll(coerce(uncoerced, subPlan, translations)); @@ -409,7 +409,7 @@ private PlanBuilder explicitCoercionSymbols(PlanBuilder subPlan, Iterable assignments.put(key, value.toSymbolReference())); @@ -661,7 +661,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica TranslationMap newTranslations = subPlan.copyTranslations(); - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); projections.putIdentities(subPlan.getRoot().getOutputSymbols()); List> descriptor = groupingSets.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 641f5e1f05760..4f0457d3a275b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -31,7 +31,6 @@ import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExceptNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.IntersectNode; @@ -175,7 +174,7 @@ protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context) if (node.getColumnNames() != null) { ImmutableList.Builder newMappings = ImmutableList.builder(); - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); // project only the visible columns from the underlying relation for (int i = 0; i < subPlan.getDescriptor().getAllFieldCount(); i++) { @@ -429,8 +428,8 @@ If casts are redundant (due to column type and common type being equal), Map leftJoinColumns = new HashMap<>(); Map rightJoinColumns = new HashMap<>(); - Assignments.Builder leftCoercions = Assignments.builder(); - Assignments.Builder rightCoercions = Assignments.builder(); + AssignmentsUtils.Builder leftCoercions = AssignmentsUtils.builder(); + AssignmentsUtils.Builder rightCoercions = AssignmentsUtils.builder(); leftCoercions.putIdentities(left.getRoot().getOutputSymbols()); rightCoercions.putIdentities(right.getRoot().getOutputSymbols()); @@ -481,7 +480,7 @@ If casts are redundant (due to column type and common type being equal), // Add a projection to produce the outputs of the columns in the USING clause, // which are defined as coalesce(l.k, r.k) - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); ImmutableList.Builder outputs = ImmutableList.builder(); for (Identifier column : joinColumns) { @@ -723,7 +722,7 @@ private RelationPlan addCoercions(RelationPlan plan, Type[] targetColumnTypes) verify(targetColumnTypes.length == oldSymbols.size()); ImmutableList.Builder newSymbols = new ImmutableList.Builder<>(); Field[] newFields = new Field[targetColumnTypes.length]; - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); for (int i = 0; i < targetColumnTypes.length; i++) { Symbol inputSymbol = oldSymbols.get(i); Type inputType = symbolAllocator.getTypes().get(inputSymbol); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index 0b2bf4fd26f15..4c1eb93b7cf4e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -198,7 +198,7 @@ private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate subPlan.getTranslations().put(inPredicate, inPredicateSubquerySymbol); - return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubquerySymbol, inPredicateSubqueryExpression), correlationAllowed); + return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), AssignmentsUtils.of(inPredicateSubquerySymbol, inPredicateSubqueryExpression), correlationAllowed); } private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set scalarSubqueries, boolean correlationAllowed) @@ -288,7 +288,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred } // add an explicit projection that removes all columns - PlanNode subqueryNode = new ProjectNode(idAllocator.getNextId(), subqueryPlan.getRoot(), Assignments.of()); + PlanNode subqueryNode = new ProjectNode(idAllocator.getNextId(), subqueryPlan.getRoot(), AssignmentsUtils.of()); Symbol exists = symbolAllocator.newSymbol("exists", BOOLEAN); subPlan.getTranslations().put(existsPredicate, exists); @@ -297,7 +297,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred subPlan, existsPredicate.getSubquery(), subqueryNode, - Assignments.of(exists, rewrittenExistsPredicate), + AssignmentsUtils.of(exists, rewrittenExistsPredicate), correlationAllowed); } @@ -395,7 +395,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa subPlan, quantifiedComparison.getSubquery(), subqueryPlan.getRoot(), - Assignments.of(coercedQuantifiedComparisonSymbol, coercedQuantifiedComparison), + AssignmentsUtils.of(coercedQuantifiedComparisonSymbol, coercedQuantifiedComparison), correlationAllowed); } @@ -570,8 +570,8 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) { ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node); - Assignments assignments = rewrittenNode.getAssignments() - .rewrite(expression -> replaceExpression(expression, mapping)); + Assignments assignments = AssignmentsUtils + .rewrite(rewrittenNode.getAssignments(), expression -> replaceExpression(expression, mapping)); return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 0fae6fe190ee6..e3f112d56ee58 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -17,11 +17,11 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -203,7 +203,7 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra result = new ProjectNode( idAllocator.getNextId(), result, - Assignments.copyOf(graph.getAssignments().get())); + AssignmentsUtils.copyOf(graph.getAssignments().get())); } // If needed, introduce a projection to constrain the outputs to what was originally expected diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index c8b0335705608..5a6b4738823e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -122,7 +123,7 @@ public Pattern getPattern() @Override public Result apply(ProjectNode projectNode, Captures captures, Context context) { - Assignments assignments = projectNode.getAssignments().rewrite(x -> rewriter.rewrite(x, context)); + Assignments assignments = AssignmentsUtils.rewrite(projectNode.getAssignments(), x -> rewriter.rewrite(x, context)); if (projectNode.getAssignments().equals(assignments)) { return Result.empty(); } @@ -311,7 +312,7 @@ public Pattern getPattern() @Override public Result apply(ApplyNode applyNode, Captures captures, Context context) { - Assignments subqueryAssignments = applyNode.getSubqueryAssignments().rewrite(x -> rewriter.rewrite(x, context)); + Assignments subqueryAssignments = AssignmentsUtils.rewrite(applyNode.getSubqueryAssignments(), x -> rewriter.rewrite(x, context)); if (applyNode.getSubqueryAssignments().equals(subqueryAssignments)) { return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index abc58c97cd884..c0d8c684e000a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -39,11 +39,11 @@ import com.facebook.presto.split.SplitSource; import com.facebook.presto.split.SplitSource.SplitBatch; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.Rule.Context; import com.facebook.presto.sql.planner.iterative.Rule.Result; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -584,7 +584,7 @@ private static Optional newRadiusSymbol(Context context, Expression expr private static PlanNode addProjection(Context context, PlanNode node, Symbol symbol, Expression expression) { - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); for (Symbol outputSymbol : node.getOutputSymbols()) { projections.putIdentity(outputSymbol); } @@ -595,7 +595,7 @@ private static PlanNode addProjection(Context context, PlanNode node, Symbol sym private static PlanNode addPartitioningNodes(Context context, PlanNode node, Symbol partitionSymbol, KdbTree kdbTree, Expression geometry, Optional radius) { - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); for (Symbol outputSymbol : node.getOutputSymbols()) { projections.putIdentity(outputSymbol); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index b4aa55db2c23e..4c80553b961b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -17,6 +17,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.PropertyPattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; @@ -140,14 +141,14 @@ protected static Optional pullWindowNodeAboveProjects( // The target node, when hoisted above the projections, will provide the symbols directly. Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( project.getAssignments().getMap(), - output -> !(project.getAssignments().isIdentity(output) && targetOutputs.contains(output))); + output -> !(AssignmentsUtils.isIdentity(project.getAssignments(), output) && targetOutputs.contains(output))); if (targetInputs.stream().anyMatch(assignmentsWithoutTargetOutputIdentities::containsKey)) { // Redefinition of an input to the target -- can't handle this case. return Optional.empty(); } - Assignments newAssignments = Assignments.builder() + Assignments newAssignments = AssignmentsUtils.builder() .putAll(assignmentsWithoutTargetOutputIdentities) .putIdentities(targetInputs) .build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index de74a7cd800da..1f85ddba6b2df 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -15,11 +15,11 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -81,7 +81,7 @@ public Pattern getPattern() @Override public Result apply(AggregationNode aggregation, Captures captures, Context context) { - Assignments.Builder newAssignments = Assignments.builder(); + AssignmentsUtils.Builder newAssignments = AssignmentsUtils.builder(); ImmutableMap.Builder aggregations = ImmutableMap.builder(); ImmutableList.Builder maskSymbols = ImmutableList.builder(); boolean aggregateWithoutFilterPresent = false; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index 56b1239d2121b..f11491ca65ab7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Rule; @@ -70,7 +71,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) } // inline the expressions - Assignments assignments = child.getAssignments().filter(targets::contains); + Assignments assignments = AssignmentsUtils.filter(child.getAssignments(), targets::contains); Map parentAssignments = parent.getAssignments() .entrySet().stream() .collect(Collectors.toMap( @@ -88,7 +89,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) .collect(toSet()); - Assignments.Builder childAssignments = Assignments.builder(); + AssignmentsUtils.Builder childAssignments = AssignmentsUtils.builder(); for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { childAssignments.put(assignment); @@ -105,7 +106,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) child.getId(), child.getSource(), childAssignments.build()), - Assignments.copyOf(parentAssignments))); + AssignmentsUtils.copyOf(parentAssignments))); } private Expression inlineReferences(Expression expression, Assignments assignments) @@ -155,7 +156,7 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN Set singletons = dependencies.entrySet().stream() .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics - .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + .filter(entry -> !AssignmentsUtils.isIdentity(child.getAssignments(), entry.getKey())) // skip identities, otherwise, this rule will keep firing forever .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java index 7e1f855310c9c..e702173a82474 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneProjectColumns.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -41,6 +42,6 @@ protected Optional pushDownProjectOff( new ProjectNode( childProjectNode.getId(), childProjectNode.getSource(), - childProjectNode.getAssignments().filter(referencedOutputs))); + AssignmentsUtils.filter(childProjectNode.getAssignments(), referencedOutputs))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 1c45f2885e280..ba6f9d7475137 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -18,13 +18,13 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -252,7 +252,7 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati Optional.empty()); // Add coalesce expressions for all aggregation functions - Assignments.Builder assignmentsBuilder = Assignments.builder(); + AssignmentsUtils.Builder assignmentsBuilder = AssignmentsUtils.builder(); for (Symbol symbol : outerJoin.getOutputSymbols()) { if (aggregationNode.getAggregations().containsKey(symbol)) { assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 9888e4bdcf6a4..f51954689ef3f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -19,13 +19,13 @@ import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.SymbolMapper; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -163,7 +163,7 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, SymbolMapper symbolMapper = mappingsBuilder.build(); AggregationNode mappedPartial = symbolMapper.map(aggregation, source, context.getIdAllocator()); - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); for (Symbol output : aggregation.getOutputSymbols()) { Symbol input = symbolMapper.map(output); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 600ac6ab241da..58ea2e8857e7f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -17,10 +17,10 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -89,7 +89,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) for (int i = 0; i < exchange.getSources().size(); i++) { Map outputToInputMap = extractExchangeOutputToInput(exchange, i); - Assignments.Builder projections = Assignments.builder(); + AssignmentsUtils.Builder projections = AssignmentsUtils.builder(); ImmutableList.Builder inputs = ImmutableList.builder(); // Need to retain the partition keys for the exchange diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index f58dcd21f1aa4..d6a9b6b079934 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -17,9 +17,9 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.UnionNode; @@ -69,7 +69,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) for (int i = 0; i < source.getSources().size(); i++) { Map outputToInput = Maps.transformValues(source.sourceSymbolMap(i), Symbol::toSymbolReference); // Map: output of union -> input of this source to the union - Assignments.Builder assignments = Assignments.builder(); // assignments for the new ProjectNode + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); // assignments for the new ProjectNode // mapping from current ProjectNode to new ProjectNode, used to identify the output layout Map projectSymbolMapping = new HashMap<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index 5b983842dac78..8aaf9c855c9be 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -18,11 +18,11 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -121,7 +121,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) new ProjectNode( context.getIdAllocator().getNextId(), node.getSource(), - Assignments.builder() + AssignmentsUtils.builder() .putIdentities(node.getSource().getOutputSymbols()) .put(partitionCountSymbol, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) .putAll(envelopeAssignments.build()) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 775e2a8db122d..9073f707f611f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -174,7 +175,7 @@ private PlanNode buildInPredicateEquivalent( ProjectNode buildSide = new ProjectNode( idAllocator.getNextId(), decorrelatedBuildSource, - Assignments.builder() + AssignmentsUtils.builder() .putIdentities(decorrelatedBuildSource.getOutputSymbols()) .put(buildSideKnownNonNull, bigint(0)) .build()); @@ -224,7 +225,7 @@ private PlanNode buildInPredicateEquivalent( return new ProjectNode( idAllocator.getNextId(), aggregation, - Assignments.builder() + AssignmentsUtils.builder() .putIdentities(apply.getInput().getOutputSymbols()) .put(inPredicateOutputSymbol, inPredicateEquivalent) .build()); @@ -322,7 +323,7 @@ public Optional visitProject(ProjectNode node, PlanNode reference) Optional result = decorrelate(node.getSource()); return result.map(decorrelated -> { - Assignments.Builder assignments = Assignments.builder() + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder() .putAll(node.getAssignments()); // Pull up all symbols used by a filter (except correlation) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index f25319afec143..408d681ca7e29 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -17,10 +17,10 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AssignUniqueId; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -154,6 +154,6 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context return Result.ofPlanNode(new ProjectNode( context.getIdAllocator().getNextId(), filterNode, - Assignments.identity(lateralJoinNode.getOutputSymbols()))); + AssignmentsUtils.identity(lateralJoinNode.getOutputSymbols()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index ec83a14e80039..8518e383ebd64 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -15,6 +15,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -75,7 +76,7 @@ public Result apply(LateralJoinNode parent, Captures captures, Context context) return Result.ofPlanNode(parent.getInput()); } else if (subqueryProjections.size() == 1) { - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .putIdentities(parent.getInput().getOutputSymbols()) .putAll(subqueryProjections.get(0).getAssignments()) .build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 1ee7f1d3e46bc..017de469904b0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -16,13 +16,13 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -119,7 +119,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C Symbol exists = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols()); Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", BOOLEAN); - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); assignments.putIdentities(applyNode.getInput().getOutputSymbols()); assignments.put(exists, new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL))); @@ -130,7 +130,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C applyNode.getSubquery(), 1L, false), - Assignments.of(subqueryTrue, TRUE_LITERAL)); + AssignmentsUtils.of(subqueryTrue, TRUE_LITERAL)); PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup()); if (!decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isPresent()) { @@ -170,7 +170,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()), - Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), + AssignmentsUtils.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), parent.getCorrelation(), INNER, parent.getOriginSubqueryError()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java index 240d81dd73cad..60e24816082c0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -82,7 +82,7 @@ public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator new ProjectNode( idAllocator.getNextId(), node, - Assignments.identity(restrictedOutputs))); + AssignmentsUtils.identity(restrictedOutputs))); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 43dbc343aed49..1440f270ee333 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -18,6 +18,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -26,7 +27,6 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -610,7 +610,7 @@ public PlanWithProperties visitProject(ProjectNode node, HashComputationSet pare PlanWithProperties child = plan(node.getSource(), sourceContext); // create a new project node with all assignments from the original node - Assignments.Builder newAssignments = Assignments.builder(); + AssignmentsUtils.Builder newAssignments = AssignmentsUtils.builder(); newAssignments.putAll(node.getAssignments()); // and all hash symbols that could be translated to the source symbols @@ -713,7 +713,7 @@ private PlanWithProperties planAndEnforce( private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashComputationSet requiredHashes) { - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); Map outputHashSymbols = new HashMap<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 7981203360f9d..26e58f71c9bcc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -19,13 +19,13 @@ import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExceptNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.IntersectNode; @@ -213,7 +213,7 @@ private List appendMarkers(List markers, List nodes, private PlanNode appendMarkers(PlanNode source, int markerIndex, List markers, Map projections) { - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); // add existing intersect symbols to projection for (Map.Entry entry : projections.entrySet()) { Symbol symbol = symbolAllocator.newSymbol(entry.getKey().getName(), symbolAllocator.getTypes().get(entry.getKey())); @@ -288,7 +288,7 @@ private ProjectNode project(PlanNode node, List columns) return new ProjectNode( idAllocator.getNextId(), node, - Assignments.identity(columns)); + AssignmentsUtils.identity(columns)); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index 5e6a8412ce978..b5fee93f6f588 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -19,6 +19,7 @@ import com.facebook.presto.metadata.ResolvedIndex; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; import com.facebook.presto.sql.planner.LiteralEncoder; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -26,7 +27,6 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; @@ -161,7 +161,7 @@ else if (leftIndexCandidate.isPresent()) { indexJoinNode = new ProjectNode( idAllocator.getNextId(), indexJoinNode, - Assignments.identity(node.getOutputSymbols())); + AssignmentsUtils.identity(node.getOutputSymbols())); } return indexJoinNode; @@ -203,7 +203,7 @@ private static PlanNode createIndexJoinWithExpectedOutputs(List expected result = new ProjectNode( idAllocator.getNextId(), result, - Assignments.identity(expectedOutputs)); + AssignmentsUtils.identity(expectedOutputs)); } return result; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index f97ce74878088..6086e10615395 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -19,13 +19,13 @@ import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -200,7 +200,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext groupBySymbols, Map aggregationOutputSymbolsMap) { - Assignments.Builder outputSymbols = Assignments.builder(); + AssignmentsUtils.Builder outputSymbols = AssignmentsUtils.builder(); ImmutableMap.Builder outputNonDistinctAggregateSymbols = ImmutableMap.builder(); for (Symbol symbol : source.getOutputSymbols()) { if (distinctSymbol.equals(symbol)) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 08b8f81c25c10..56a83ccb123c4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -15,6 +15,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; @@ -259,7 +260,7 @@ public Optional visitProject(ProjectNode node, Void context .filter(symbol -> !nodeOutputSymbols.contains(symbol)) .collect(toImmutableList()); - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .putAll(node.getAssignments()) .putIdentities(symbolsToAdd) .build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index 7df2da4623f59..02ce4cf39d3ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.EffectivePredicateExtractor; import com.facebook.presto.sql.planner.EqualityInference; @@ -32,7 +33,6 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -453,12 +453,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode output = node; // Create identity projections for all existing symbols - Assignments.Builder leftProjections = Assignments.builder(); + AssignmentsUtils.Builder leftProjections = AssignmentsUtils.builder(); leftProjections.putAll(node.getLeft() .getOutputSymbols().stream() .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); - Assignments.Builder rightProjections = Assignments.builder(); + AssignmentsUtils.Builder rightProjections = AssignmentsUtils.builder(); rightProjections.putAll(node.getRight() .getOutputSymbols().stream() .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); @@ -537,7 +537,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } if (!node.getOutputSymbols().equals(output.getOutputSymbols())) { - output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputSymbols())); + output = new ProjectNode(idAllocator.getNextId(), output, AssignmentsUtils.identity(node.getOutputSymbols())); } return output; @@ -602,12 +602,12 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext key, Symbol::toSymbolReference))); - Assignments.Builder rightProjections = Assignments.builder(); + AssignmentsUtils.Builder rightProjections = AssignmentsUtils.builder(); rightProjections.putAll(node.getRight() .getOutputSymbols().stream() .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 71d172d057eb6..27c377c73146b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PartitioningScheme; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; @@ -27,7 +28,6 @@ import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.ExceptNode; @@ -523,7 +523,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext> conte { ImmutableSet.Builder expectedInputs = ImmutableSet.builder(); - Assignments.Builder builder = Assignments.builder(); + AssignmentsUtils.Builder builder = AssignmentsUtils.builder(); node.getAssignments().forEach((symbol, expression) -> { if (context.get().contains(symbol)) { expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); @@ -776,7 +776,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) // extract symbols required subquery plan ImmutableSet.Builder subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder(); - Assignments.Builder subqueryAssignments = Assignments.builder(); + AssignmentsUtils.Builder subqueryAssignments = AssignmentsUtils.builder(); for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { Symbol output = entry.getKey(); Expression expression = entry.getValue(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index c17bc9a43efc3..44f281e091f09 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -79,7 +80,7 @@ public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, Aggreg } Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); - Assignments scalarAggregationSourceAssignments = Assignments.builder() + Assignments scalarAggregationSourceAssignments = AssignmentsUtils.builder() .putIdentities(source.get().getNode().getOutputSymbols()) .put(nonNull, TRUE_LITERAL) .build(); @@ -140,7 +141,7 @@ private PlanNode rewriteScalarAggregation( List aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregationNode.get()); if (subqueryProjection.isPresent()) { - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .putIdentities(aggregationOutputSymbols) .putAll(subqueryProjection.get().getAssignments()) .build(); @@ -154,7 +155,7 @@ private PlanNode rewriteScalarAggregation( return new ProjectNode( idAllocator.getNextId(), aggregationNode.get(), - Assignments.identity(aggregationOutputSymbols)); + AssignmentsUtils.identity(aggregationOutputSymbols)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index 9d0720764ed42..decebd029cfbc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -180,7 +181,7 @@ countNonNullValue, new Aggregation( Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().getSymbols()); - return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery)); + return projectExpressions(lateralJoinNode, AssignmentsUtils.of(quantifiedComparisonSymbol, valueComparedToSubquery)); } public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) @@ -255,7 +256,7 @@ private static boolean shouldCompareValueWithLowerBound(QuantifiedComparisonExpr private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) { - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .putIdentities(input.getOutputSymbols()) .putAll(subqueryAssignments) .build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 45ead3d8f454f..98bf63beea718 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -598,7 +599,7 @@ private void map(Symbol symbol, Symbol canonical) private Assignments canonicalize(Assignments oldAssignments) { Map computedExpressions = new HashMap<>(); - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); for (Map.Entry entry : oldAssignments.getMap().entrySet()) { Expression expression = canonicalize(entry.getValue()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 749613cd80d8c..634e37e4baee4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -15,69 +15,22 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.ExpressionRewriter; -import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.sql.tree.SymbolReference; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; import java.util.Collection; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.function.BiConsumer; -import java.util.function.Function; -import java.util.stream.Collector; -import static com.google.common.base.Preconditions.checkState; -import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; public class Assignments { - public static Builder builder() - { - return new Builder(); - } - - public static Assignments identity(Symbol... symbols) - { - return identity(asList(symbols)); - } - - public static Assignments identity(Iterable symbols) - { - return builder().putIdentities(symbols).build(); - } - - public static Assignments copyOf(Map assignments) - { - return builder() - .putAll(assignments) - .build(); - } - - public static Assignments of() - { - return builder().build(); - } - - public static Assignments of(Symbol symbol, Expression expression) - { - return builder().put(symbol, expression).build(); - } - - public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2) - { - return builder().put(symbol1, expression1).put(symbol2, expression2).build(); - } - private final Map assignments; @JsonCreator @@ -97,49 +50,6 @@ public Map getMap() return assignments; } - public Assignments rewrite(ExpressionRewriter rewriter) - { - return rewrite(expression -> ExpressionTreeRewriter.rewriteWith(rewriter, expression)); - } - - public Assignments rewrite(Function rewrite) - { - return assignments.entrySet().stream() - .map(entry -> Maps.immutableEntry(entry.getKey(), rewrite.apply(entry.getValue()))) - .collect(toAssignments()); - } - - public Assignments filter(Collection symbols) - { - return filter(symbols::contains); - } - - public Assignments filter(Predicate predicate) - { - return assignments.entrySet().stream() - .filter(entry -> predicate.apply(entry.getKey())) - .collect(toAssignments()); - } - - public boolean isIdentity(Symbol output) - { - Expression expression = assignments.get(output); - - return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); - } - - private Collector, Builder, Assignments> toAssignments() - { - return Collector.of( - Assignments::builder, - (builder, entry) -> builder.put(entry.getKey(), entry.getValue()), - (left, right) -> { - left.putAll(right.build()); - return left; - }, - Assignments.Builder::build); - } - public Collection getExpressions() { return assignments.values(); @@ -195,62 +105,4 @@ public int hashCode() { return assignments.hashCode(); } - - public static class Builder - { - private final Map assignments = new LinkedHashMap<>(); - - public Builder putAll(Assignments assignments) - { - return putAll(assignments.getMap()); - } - - public Builder putAll(Map assignments) - { - for (Entry assignment : assignments.entrySet()) { - put(assignment.getKey(), assignment.getValue()); - } - return this; - } - - public Builder put(Symbol symbol, Expression expression) - { - if (assignments.containsKey(symbol)) { - Expression assignment = assignments.get(symbol); - checkState( - assignment.equals(expression), - "Symbol %s already has assignment %s, while adding %s", - symbol, - assignment, - expression); - } - assignments.put(symbol, expression); - return this; - } - - public Builder put(Entry assignment) - { - put(assignment.getKey(), assignment.getValue()); - return this; - } - - public Builder putIdentities(Iterable symbols) - { - for (Symbol symbol : symbols) { - putIdentity(symbol); - } - return this; - } - - public Builder putIdentity(Symbol symbol) - { - put(symbol, symbol.toSymbolReference()); - return this; - } - - public Assignments build() - { - return new Assignments(assignments); - } - } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 7ee8b02af4445..87ceeda2f9d48 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanFragmenter; @@ -40,7 +41,6 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -759,7 +759,7 @@ private PlanNode project(String id, PlanNode source, String symbol, Expression e return new ProjectNode( new PlanNodeId(id), source, - Assignments.of(new Symbol(symbol), expression)); + AssignmentsUtils.of(new Symbol(symbol), expression)); } private AggregationNode aggregation(String id, PlanNode source) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java index cfe4e0e633968..271ec64ad892e 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java @@ -15,8 +15,8 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.tree.Expression; @@ -126,9 +126,9 @@ public void testFilterPositiveNarrowingProjectSemiJoin(boolean toRowExpression) if (toRowExpression) { return pb.filter( TRANSLATOR.translateAndOptimize(expression("sjo"), pb.getTypes()), - pb.project(Assignments.identity(semiJoinOutput, a), semiJoinNode)); + pb.project(AssignmentsUtils.identity(semiJoinOutput, a), semiJoinNode)); } - return pb.filter(expression("sjo"), pb.project(Assignments.identity(semiJoinOutput, a), semiJoinNode)); + return pb.filter(expression("sjo"), pb.project(AssignmentsUtils.identity(semiJoinOutput, a), semiJoinNode)); }) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 6378fe851be2b..1fb94c5bbd81f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -25,7 +25,6 @@ import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; @@ -213,7 +212,7 @@ public void testProject() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - Assignments.of(D, AE, E, CE)); + AssignmentsUtils.of(D, AE, E, CE)); Expression effectivePredicate = effectivePredicateExtractor.extract(node); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index ec32e38d3aa9c..8a7bf680a36cd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -114,7 +114,7 @@ public void testValidProject() { Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); Expression expression2 = new Cast(columnC.toSymbolReference(), StandardTypes.BIGINT); - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1) .put(symbolAllocator.newSymbol(expression2, BIGINT), expression2) .build(); @@ -201,7 +201,7 @@ public void testValidAggregation() public void testValidTypeOnlyCoercion() { Expression expression = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .put(symbolAllocator.newSymbol(expression, BIGINT), expression) .put(symbolAllocator.newSymbol(columnE.toSymbolReference(), VARCHAR), columnE.toSymbolReference()) // implicit coercion from varchar(3) to varchar .build(); @@ -215,7 +215,7 @@ public void testInvalidProject() { Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.INTEGER); Expression expression2 = new Cast(columnA.toSymbolReference(), StandardTypes.INTEGER); - Assignments assignments = Assignments.builder() + Assignments assignments = AssignmentsUtils.builder() .put(symbolAllocator.newSymbol(expression1, BIGINT), expression1) // should be INTEGER .put(symbolAllocator.newSymbol(expression1, INTEGER), expression2) .build(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java index 4911224f3fc20..b6bcf75b29016 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java @@ -16,10 +16,10 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanVisitor; @@ -63,7 +63,7 @@ public MatchResult visitExchange(ExchangeNode node, PlanMatchPattern pattern) SymbolAliases newAliases = result.getAliases(); for (List inputs : allInputs) { - Assignments.Builder assignments = Assignments.builder(); + AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); for (int i = 0; i < inputs.size(); ++i) { assignments.put(outputs.get(i), inputs.get(i).toSymbolReference()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java index cd16340416e61..044c7f12240ad 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java @@ -18,9 +18,9 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.testing.LocalQueryRunner; @@ -107,7 +107,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) if (isIdentityProjection(project)) { return Result.ofPlanNode(project.getSource()); } - PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, Assignments.identity(project.getOutputSymbols())); + PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, AssignmentsUtils.identity(project.getOutputSymbols())); return Result.ofPlanNode(projectNode); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleIndex.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleIndex.java index 8cbf04f2b36e9..e9039d01db831 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleIndex.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestRuleIndex.java @@ -16,9 +16,9 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -51,7 +51,7 @@ public void testWithPlanNodeHierarchy() .register(anyRule) .build(); - ProjectNode projectNode = planBuilder.project(Assignments.of(), planBuilder.values()); + ProjectNode projectNode = planBuilder.project(AssignmentsUtils.of(), planBuilder.values()); FilterNode filterNode = planBuilder.filter(BooleanLiteral.TRUE_LITERAL, planBuilder.values()); ValuesNode valuesNode = planBuilder.values(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 50a3d14687e34..5ce6956ec84d0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; @@ -316,7 +316,7 @@ public void testInterimProject() p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.project( - Assignments.identity(p.symbol("b")), + AssignmentsUtils.identity(p.symbol("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT)) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index dffc428870741..ffc165d6f2d7d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -250,7 +250,7 @@ private PlanNode projectNode(PlanNode source, String symbol, Expression expressi return new ProjectNode( idAllocator.getNextId(), source, - Assignments.of(new Symbol(symbol), expression)); + AssignmentsUtils.of(new Symbol(symbol), expression)); } private String symbol(String name) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index 5a420691bd7df..e59fe67881f8a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.type.DateType; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; @@ -58,7 +58,7 @@ public void testProjectionExpressionRewrite() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), PlanBuilder.expression("x IS NOT NULL")), + AssignmentsUtils.of(p.symbol("y"), PlanBuilder.expression("x IS NOT NULL")), p.values(p.symbol("x")))) .matches( project(ImmutableMap.of("y", expression("0")), values("x"))); @@ -69,7 +69,7 @@ public void testProjectionExpressionNotRewritten() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.symbol("y"), PlanBuilder.expression("0")), + AssignmentsUtils.of(p.symbol("y"), PlanBuilder.expression("0")), p.values(p.symbol("x")))) .doesNotFire(); } @@ -150,7 +150,7 @@ public void testApplyExpressionRewrite() { tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( - Assignments.of( + AssignmentsUtils.of( p.symbol("a", BIGINT), new InPredicate( new LongLiteral("1"), @@ -173,7 +173,7 @@ public void testApplyExpressionNotRewritten() { tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( - Assignments.of( + AssignmentsUtils.of( p.symbol("a", BIGINT), new InPredicate( new LongLiteral("0"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index efd8fcf6c5b20..d92b13231c8ed 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -33,7 +33,7 @@ public void test() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.builder() + AssignmentsUtils.builder() .put(p.symbol("identity"), expression("symbol")) // identity .put(p.symbol("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times .put(p.symbol("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times @@ -42,7 +42,7 @@ public void test() .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once .put(p.symbol("try"), expression("try(complex / literal)")) .build(), - p.project(Assignments.builder() + p.project(AssignmentsUtils.builder() .put(p.symbol("symbol"), expression("x")) .put(p.symbol("complex"), expression("x * 2")) .put(p.symbol("literal"), expression("1")) @@ -73,9 +73,9 @@ public void testIdentityProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.symbol("output"), expression("value")), + AssignmentsUtils.of(p.symbol("output"), expression("value")), p.project( - Assignments.identity(p.symbol("value")), + AssignmentsUtils.identity(p.symbol("value")), p.values(p.symbol("value"))))) .doesNotFire(); } @@ -86,9 +86,9 @@ public void testSubqueryProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.identity(p.symbol("fromOuterScope"), p.symbol("value")), + AssignmentsUtils.identity(p.symbol("fromOuterScope"), p.symbol("value")), p.project( - Assignments.identity(p.symbol("value")), + AssignmentsUtils.identity(p.symbol("value")), p.values(p.symbol("value"))))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 8c9fd78c911fa..b1965cf83c831 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -14,11 +14,11 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; @@ -181,12 +181,12 @@ public void testIntermediateProjectNodes() newWindowNodeSpecification(p, "a"), ImmutableMap.of(p.symbol("lagOutput"), newWindowNodeFunction("lag", windowA, "a", "one")), p.project( - Assignments.builder() + AssignmentsUtils.builder() .put(p.symbol("one"), expression("CAST(1 AS bigint)")) .putIdentities(ImmutableList.of(p.symbol("a"), p.symbol("avgOutput"))) .build(), p.project( - Assignments.identity(p.symbol("a"), p.symbol("avgOutput"), p.symbol("unused")), + AssignmentsUtils.identity(p.symbol("a"), p.symbol("avgOutput"), p.symbol("unused")), p.window( newWindowNodeSpecification(p, "a"), ImmutableMap.of(p.symbol("avgOutput"), newWindowNodeFunction("avg", windowA, "a")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java index 97f99717e4a52..b0dab234e5108 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -71,7 +71,7 @@ private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate Symbol b = planBuilder.symbol("b"); Symbol key = planBuilder.symbol("key"); return planBuilder.project( - Assignments.identity(ImmutableList.of(a, b).stream().filter(projectionFilter).collect(toImmutableSet())), + AssignmentsUtils.identity(ImmutableList.of(a, b).stream().filter(projectionFilter).collect(toImmutableSet())), planBuilder.aggregation(aggregationBuilder -> aggregationBuilder .source(planBuilder.values(key)) .singleGroupingSet(key) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 1db68352cc530..0dbf3b833f037 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -15,11 +15,11 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SymbolReference; @@ -142,7 +142,7 @@ public void testDoesNotFireOnNestedNonCountAggregate() .globalGrouping() .source( p.project( - Assignments.of(totalPrice, totalPrice.toSymbolReference()), + AssignmentsUtils.of(totalPrice, totalPrice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java index 3a68548a4398f..65559bc59f7c5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.base.Predicates; @@ -89,7 +89,7 @@ private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate Symbol rightValue = p.symbol("rightValue"); List outputs = ImmutableList.of(leftValue, rightValue); return p.project( - Assignments.identity( + AssignmentsUtils.identity( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java index 3c3df254f4370..52b86f89429e0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -70,7 +70,7 @@ private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate 5"), planBuilder.values(a, b))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java index c49d2707f28fb..5b6977a51a3b6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -18,10 +18,10 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; @@ -78,7 +78,7 @@ private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate proj Symbol rightValue = p.symbol("rightValue"); List outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); return p.project( - Assignments.identity( + AssignmentsUtils.identity( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java index a2c79c372da49..baff57fc2c210 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -62,7 +62,7 @@ private ProjectNode buildProjectedLimit(PlanBuilder planBuilder, Predicate Symbol rightKey = p.symbol("rightKey"); List outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); return p.project( - Assignments.identity( + AssignmentsUtils.identity( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java index 5127108823a9e..35aab979d0328 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -15,10 +15,10 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; @@ -45,7 +45,7 @@ public void testNotAllOutputsReferenced() Symbol orderdate = p.symbol("orderdate", DATE); Symbol totalprice = p.symbol("totalprice", DOUBLE); return p.project( - Assignments.of(p.symbol("x"), totalprice.toSymbolReference()), + AssignmentsUtils.of(p.symbol("x"), totalprice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), @@ -67,7 +67,7 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneTableScanColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), + AssignmentsUtils.of(p.symbol("y"), expression("x")), p.tableScan( ImmutableList.of(p.symbol("x")), ImmutableMap.of(p.symbol("x"), new TestingColumnHandle("x"))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java index 486ef116d2c4b..3a006874ab5d7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; @@ -76,7 +76,7 @@ private ProjectNode buildProjectedTopN(PlanBuilder planBuilder, Predicate p.project( - Assignments.of(p.symbol("y"), expression("x")), + AssignmentsUtils.of(p.symbol("y"), expression("x")), p.values( ImmutableList.of(p.symbol("unused"), p.symbol("x")), ImmutableList.of( @@ -57,7 +57,7 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), + AssignmentsUtils.of(p.symbol("y"), expression("x")), p.values(p.symbol("x")))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index 8427deb005342..bd399410d532b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -15,13 +15,13 @@ import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.tree.Expression; @@ -202,7 +202,7 @@ private static PlanNode buildProjectedWindow( List outputs = ImmutableList.builder().addAll(inputs).add(output1, output2).build(); return p.project( - Assignments.identity( + AssignmentsUtils.identity( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index d6307c1bd4b6c..fa027c71daa25 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -148,7 +148,7 @@ public void testDoesNotFireWhenNotDistinct() .source( p.join( JoinNode.Type.LEFT, - p.project(Assignments.builder() + p.project(AssignmentsUtils.builder() .putIdentity(p.symbol("COL1", BIGINT)) .build(), p.aggregation(builder -> diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java index d6ea56e937a4b..642042fe97929 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -36,7 +36,7 @@ public void testPushdownLimitNonIdentityProjection() Symbol a = p.symbol("a"); return p.limit(1, p.project( - Assignments.of(a, TRUE_LITERAL), + AssignmentsUtils.of(a, TRUE_LITERAL), p.values())); }) .matches( @@ -53,7 +53,7 @@ public void testDoesntPushdownLimitThroughIdentityProjection() Symbol a = p.symbol("a"); return p.limit(1, p.project( - Assignments.of(a, a.toSymbolReference()), + AssignmentsUtils.of(a, a.toSymbolReference()), p.values(a))); }).doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index 0a0e7c1082041..5463419235431 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.SymbolReference; @@ -44,7 +44,7 @@ public void testDoesNotFireNoExchange() tester().assertThat(new PushProjectionThroughExchange()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new LongLiteral("3")), + AssignmentsUtils.of(p.symbol("x"), new LongLiteral("3")), p.values(p.symbol("a")))) .doesNotFire(); } @@ -59,7 +59,7 @@ public void testDoesNotFireNarrowingProjection() Symbol c = p.symbol("c"); return p.project( - Assignments.builder() + AssignmentsUtils.builder() .put(a, a.toSymbolReference()) .put(b, b.toSymbolReference()) .build(), @@ -82,7 +82,7 @@ public void testSimpleMultipleInputs() Symbol c2 = p.symbol("c2"); Symbol x = p.symbol("x"); return p.project( - Assignments.of( + AssignmentsUtils.of( x, new LongLiteral("3"), c2, new SymbolReference("c")), p.exchange(e -> e @@ -119,7 +119,7 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() Symbol bTimes5 = p.symbol("b_times_5"); Symbol hTimes5 = p.symbol("h_times_5"); return p.project( - Assignments.builder() + AssignmentsUtils.builder() .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) .put(bTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5"))) .put(hTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5"))) @@ -163,7 +163,7 @@ public void testOrderingColumnsArePreserved() Symbol sortSymbol = p.symbol("sortSymbol"); OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortSymbol), ImmutableMap.of(sortSymbol, SortOrder.ASC_NULLS_FIRST)); return p.project( - Assignments.builder() + AssignmentsUtils.builder() .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) .put(bTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5"))) .put(hTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5"))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index 2b22593cb6fa5..093f29267ed12 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.google.common.collect.ImmutableList; @@ -37,7 +37,7 @@ public void testDoesNotFire() tester().assertThat(new PushProjectionThroughUnion()) .on(p -> p.project( - Assignments.of(p.symbol("x"), new LongLiteral("3")), + AssignmentsUtils.of(p.symbol("x"), new LongLiteral("3")), p.values(p.symbol("a")))) .doesNotFire(); } @@ -52,7 +52,7 @@ public void test() Symbol c = p.symbol("c"); Symbol cTimes3 = p.symbol("c_times_3"); return p.project( - Assignments.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), + AssignmentsUtils.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), p.union( ImmutableListMultimap.builder() .put(c, a) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java index 1532466c06cbb..16450180a9e0a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -29,7 +29,7 @@ public void testDoesNotFire() { tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) .on(p -> p.apply( - Assignments.of(p.symbol("z"), p.expression("x IN (y)")), + AssignmentsUtils.of(p.symbol("z"), p.expression("x IN (y)")), ImmutableList.of(), p.values(p.symbol("x")), p.values(p.symbol("y")))) @@ -41,7 +41,7 @@ public void testEmptyAssignments() { tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) .on(p -> p.apply( - Assignments.of(), + AssignmentsUtils.of(), ImmutableList.of(), p.values(p.symbol("x")), p.values(p.symbol("y")))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java index afbc3ab6b62ba..923358411fbca 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -106,7 +106,7 @@ public void rewritesOnSubqueryWithProjection() .on(p -> p.lateral( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(p.symbol("expr"), p.expression("sum + 1")), + p.project(AssignmentsUtils.of(p.symbol("expr"), p.expression("sum + 1")), p.aggregation(ab -> ab .source(p.values(p.symbol("a"), p.symbol("b"))) .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index cd70861189a7c..d7d17ade33556 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -17,10 +17,10 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -124,7 +124,7 @@ public void rewritesOnSubqueryWithProjection() p.values(p.symbol("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), p.expression("a * 2")), + AssignmentsUtils.of(p.symbol("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS)))))) @@ -153,10 +153,10 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.project( - Assignments.of(p.symbol("a3"), p.expression("a2 + 1")), + AssignmentsUtils.of(p.symbol("a3"), p.expression("a2 + 1")), p.enforceSingleRow( p.project( - Assignments.of(p.symbol("a2"), p.expression("a * 2")), + AssignmentsUtils.of(p.symbol("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), TWO_ROWS))))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 94612fac9ba78..f4c00bdf525c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -15,9 +15,9 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; @@ -54,7 +54,7 @@ public void testRewrite() ImmutableMap.of(p.symbol("l_nationkey"), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.symbol("l_expr2"), expression("l_nationkey + 1")), + AssignmentsUtils.of(p.symbol("l_expr2"), expression("l_nationkey + 1")), p.values( ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java index 5326704f45481..7f944f3b96c7f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -56,7 +56,7 @@ public void testRewrite() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + AssignmentsUtils.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), ImmutableList.of(), p.values(), p.values())) @@ -75,10 +75,10 @@ public void testRewritesToLimit() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + AssignmentsUtils.of(p.symbol("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), - p.project(Assignments.of(), + p.project(AssignmentsUtils.of(), p.filter( expression("corr = column"), p.values(p.symbol("column")))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 47255b4fc74b4..5d3f7f56e8c68 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.SemiJoinNode; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.InPredicate; @@ -34,7 +34,7 @@ public void testDoesNotFireOnNoCorrelation() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of(), + AssignmentsUtils.of(), emptyList(), p.values(), p.values())) @@ -46,7 +46,7 @@ public void testDoesNotFireOnNonInPredicateSubquery() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of(p.symbol("x"), new ExistsPredicate(new LongLiteral("1"))), + AssignmentsUtils.of(p.symbol("x"), new ExistsPredicate(new LongLiteral("1"))), emptyList(), p.values(), p.values())) @@ -58,7 +58,7 @@ public void testFiresForInPredicate() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of( + AssignmentsUtils.of( p.symbol("x"), new InPredicate( new SymbolReference("y"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java index 729c330d3fb02..9644f67a2a192 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java @@ -15,8 +15,8 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -35,7 +35,7 @@ public void testReportWrongMatch() tester.assertThat(new DummyReplaceNodeRule()) .on(p -> p.project( - Assignments.of(p.symbol("y"), expression("x")), + AssignmentsUtils.of(p.symbol("y"), expression("x")), p.values( ImmutableList.of(p.symbol("x")), ImmutableList.of(constantExpressions(BIGINT, 1))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java index 4f121265984e5..3dae0d7b3e9f0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.google.common.collect.ImmutableCollection; import org.testng.annotations.Test; @@ -22,7 +23,7 @@ public class TestAssingments { - private final Assignments assignments = Assignments.of(new Symbol("test"), TRUE_LITERAL); + private final Assignments assignments = AssignmentsUtils.of(new Symbol("test"), TRUE_LITERAL); @Test public void testOutputsImmutable() diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java index f1f448d45b1af..89c2edb07d859 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestVerifyOnlyOneOutputNode.java @@ -14,9 +14,9 @@ package com.facebook.presto.sql.planner.sanity; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; @@ -38,7 +38,7 @@ public void testValidateSuccessful() new ProjectNode(idAllocator.getNextId(), new ValuesNode( idAllocator.getNextId(), ImmutableList.of(), ImmutableList.of()), - Assignments.of() + AssignmentsUtils.of() ), ImmutableList.of(), ImmutableList.of()); new VerifyOnlyOneOutputNode().validate(root, null, null, null, null, WarningCollector.NOOP); } @@ -54,7 +54,7 @@ public void testValidateFailed() new ProjectNode(idAllocator.getNextId(), new ValuesNode( idAllocator.getNextId(), ImmutableList.of(), ImmutableList.of()), - Assignments.of() + AssignmentsUtils.of() ), ImmutableList.of(), ImmutableList.of() ), new Symbol("a"), false), From 868415050bb9571180ee3406906affd5ae5fa664 Mon Sep 17 00:00:00 2001 From: James Sun Date: Mon, 22 Apr 2019 18:28:49 -0700 Subject: [PATCH 2/6] Extract ApplyNode::isSupportedSubqueryExpression to uility --- .../presto/sql/planner/SubqueryPlanner.java | 2 + .../rule/ExpressionRewriteRuleSet.java | 2 + .../planner/optimizations/ApplyNodeUtil.java | 43 +++++++++++++++++++ .../PruneUnreferencedOutputs.java | 6 ++- .../UnaliasSymbolReferences.java | 5 ++- .../presto/sql/planner/plan/ApplyNode.java | 14 ------ .../iterative/rule/test/PlanBuilder.java | 2 + 7 files changed, 58 insertions(+), 16 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index 4c1eb93b7cf4e..d9e089b90adf4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -59,6 +59,7 @@ import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.analyzer.SemanticExceptions.subQueryNotSupportedError; import static com.facebook.presto.sql.planner.ExpressionNodeInliner.replaceExpression; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -439,6 +440,7 @@ private PlanBuilder appendApplyNode(PlanBuilder subPlan, Node subquery, PlanNode TranslationMap translations = subPlan.copyTranslations(); PlanNode root = subPlan.getRoot(); + verifySubquerySupported(subqueryAssignments); return new PlanBuilder(translations, new ApplyNode(idAllocator.getNextId(), root, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index 5a6b4738823e8..3b8e5fed0ed42 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -38,6 +38,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.planner.plan.Patterns.filter; @@ -316,6 +317,7 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) if (applyNode.getSubqueryAssignments().equals(subqueryAssignments)) { return Result.empty(); } + verifySubquerySupported(subqueryAssignments); return Result.ofPlanNode(new ApplyNode( applyNode.getId(), applyNode.getInput(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java new file mode 100644 index 0000000000000..7f45d6feced8c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.ExistsPredicate; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; + +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.google.common.base.Preconditions.checkArgument; + +public final class ApplyNodeUtil +{ + private ApplyNodeUtil() {} + + public static void verifySubquerySupported(Assignments assignments) + { + checkArgument( + assignments.getExpressions().stream().allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), + "Unexpected expression used for subquery expression"); + } + + private static boolean isSupportedSubqueryExpression(RowExpression expression) + { + // TODO: add RowExpression support + return castToExpression(expression) instanceof InPredicate || + castToExpression(expression) instanceof ExistsPredicate || + castToExpression(expression) instanceof QuantifiedComparisonExpression; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 27c377c73146b..5c9b82264726b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -28,6 +28,7 @@ import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.DistinctLimitNode; import com.facebook.presto.sql.planner.plan.ExceptNode; @@ -83,6 +84,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -801,7 +803,9 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) .addAll(subqueryAssignmentsSymbols) // need to include those: e.g: "expr" from "expr IN (SELECT 1)" .build(); PlanNode input = context.rewrite(node.getInput(), inputContext); - return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation, node.getOriginSubqueryError()); + Assignments assignments = subqueryAssignments.build(); + verifySubquerySupported(assignments); + return new ApplyNode(node.getId(), input, subquery, assignments, newCorrelation, node.getOriginSubqueryError()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 98bf63beea718..bc30dd64064d2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -89,6 +89,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -456,7 +457,9 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) PlanNode subquery = context.rewrite(node.getSubquery()); List canonicalCorrelation = Lists.transform(node.getCorrelation(), this::canonicalize); - return new ApplyNode(node.getId(), source, subquery, canonicalize(node.getSubqueryAssignments()), canonicalCorrelation, node.getOriginSubqueryError()); + Assignments assignments = canonicalize(node.getSubqueryAssignments()); + verifySubquerySupported(assignments); + return new ApplyNode(node.getId(), source, subquery, assignments, canonicalCorrelation, node.getOriginSubqueryError()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index 356ac9bcd124c..aab3bf7c9aa10 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -14,10 +14,6 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.tree.ExistsPredicate; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.InPredicate; -import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -84,9 +80,6 @@ public ApplyNode( requireNonNull(originSubqueryError, "originSubqueryError is null"); checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); - checkArgument( - subqueryAssignments.getExpressions().stream().allMatch(ApplyNode::isSupportedSubqueryExpression), - "Unexpected expression used for subquery expression"); this.input = input; this.subquery = subquery; @@ -95,13 +88,6 @@ public ApplyNode( this.originSubqueryError = originSubqueryError; } - private static boolean isSupportedSubqueryExpression(Expression expression) - { - return expression instanceof InPredicate || - expression instanceof ExistsPredicate || - expression instanceof QuantifiedComparisonExpression; - } - @JsonProperty("input") public PlanNode getInput() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index ce9eeef29bac9..e9ce53776ba18 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -97,6 +97,7 @@ import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.constantNull; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -361,6 +362,7 @@ protected AggregationNode build() public ApplyNode apply(Assignments subqueryAssignments, List correlation, PlanNode input, PlanNode subquery) { + verifySubquerySupported(subqueryAssignments); return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation, ""); } From d4f9b03c220edc8e20fb9521870a11becd2beec0 Mon Sep 17 00:00:00 2001 From: James Sun Date: Mon, 22 Apr 2019 00:06:59 -0700 Subject: [PATCH 3/6] Replace Assignments::Expression with RowExpression --- .../presto/cost/ProjectStatsRule.java | 14 ++- .../presto/sql/planner/AssignmentsUtils.java | 29 +++--- .../planner/EffectivePredicateExtractor.java | 4 +- .../sql/planner/ExpressionExtractor.java | 3 +- .../sql/planner/LocalExecutionPlanner.java | 48 ++++------ .../presto/sql/planner/LogicalPlanner.java | 7 +- .../presto/sql/planner/PlanBuilder.java | 5 +- .../presto/sql/planner/PlanFragmenter.java | 20 +++-- .../presto/sql/planner/QueryPlanner.java | 19 ++-- .../presto/sql/planner/RelationPlanner.java | 22 ++--- .../presto/sql/planner/SubqueryPlanner.java | 6 +- .../iterative/rule/EliminateCrossJoins.java | 4 +- .../iterative/rule/ExtractSpatialJoins.java | 5 +- .../iterative/rule/GatherAndMergeWindows.java | 8 +- .../rule/ImplementFilteredAggregations.java | 2 +- .../iterative/rule/InlineProjections.java | 30 ++++--- .../rule/ProjectOffPushDownRule.java | 4 +- .../rule/PushAggregationThroughOuterJoin.java | 4 +- ...PushPartialAggregationThroughExchange.java | 3 +- .../rule/PushProjectionThroughExchange.java | 20 +++-- .../rule/PushProjectionThroughUnion.java | 9 +- ...RewriteSpatialPartitioningAggregation.java | 10 ++- .../rule/SimplifyCountOverConstant.java | 3 +- .../TransformCorrelatedInPredicateToJoin.java | 7 +- .../TransformExistsApplyToLateralNode.java | 10 ++- ...rrelatedInPredicateSubqueryToSemiJoin.java | 3 +- .../planner/optimizations/AddExchanges.java | 7 +- .../HashGenerationOptimizer.java | 17 ++-- .../ImplementIntersectAndExceptAsUnion.java | 4 +- .../optimizations/IndexJoinOptimizer.java | 8 +- .../optimizations/MetadataQueryOptimizer.java | 4 +- .../OptimizeMixedDistinctAggregations.java | 11 +-- .../optimizations/PredicatePushDown.java | 34 +++---- .../optimizations/PropertyDerivations.java | 78 +++++++++++----- .../PruneUnreferencedOutputs.java | 15 +++- .../ScalarAggregationToJoinRewriter.java | 3 +- .../StreamPropertyDerivations.java | 21 +++-- ...uantifiedComparisonApplyToLateralJoin.java | 6 +- .../optimizations/TranslateExpressions.java | 88 ++++++++++++++++++- .../UnaliasSymbolReferences.java | 6 +- .../optimizations/joins/JoinGraph.java | 4 +- .../presto/sql/planner/plan/Assignments.java | 16 ++-- .../sql/planner/planPrinter/PlanPrinter.java | 8 +- .../sql/planner/sanity/TypeValidator.java | 26 ++++-- .../sanity/ValidateDependenciesChecker.java | 22 +++-- .../sql/relational/ProjectNodeUtils.java | 7 +- .../facebook/presto/util/GraphvizPrinter.java | 9 +- 47 files changed, 445 insertions(+), 248 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java index cb6ad1e739ba4..1587f22844f76 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java @@ -15,16 +15,18 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.tree.Expression; import java.util.Map; import java.util.Optional; import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static java.util.Objects.requireNonNull; public class ProjectStatsRule @@ -53,8 +55,14 @@ protected Optional doCalculate(ProjectNode node, StatsPro PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(sourceStats.getOutputRowCount()); - for (Map.Entry entry : node.getAssignments().entrySet()) { - calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types)); + for (Map.Entry entry : node.getAssignments().entrySet()) { + RowExpression expression = entry.getValue(); + if (isExpression(expression)) { + calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(castToExpression(expression), sourceStats, session, types)); + } + else { + calculatedStats.addSymbolStatistics(entry.getKey(), scalarStatsCalculator.calculate(expression, sourceStats, session)); + } } return Optional.of(calculatedStats.build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java index 97d4deb4b8cdd..eda555860963f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -27,6 +28,8 @@ import java.util.function.Predicate; import java.util.stream.Collector; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkState; import static java.util.Arrays.asList; @@ -53,7 +56,7 @@ public static Assignments identity(Iterable symbols) return builder().putIdentities(symbols).build(); } - public static Assignments copyOf(Map assignments) + public static Assignments copyOf(Map assignments) { return builder() .putAll(assignments) @@ -65,12 +68,12 @@ public static Assignments of() return builder().build(); } - public static Assignments of(Symbol symbol, Expression expression) + public static Assignments of(Symbol symbol, RowExpression expression) { return builder().put(symbol, expression).build(); } - public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2) + public static Assignments of(Symbol symbol1, RowExpression expression1, Symbol symbol2, RowExpression expression2) { return builder().put(symbol1, expression1).put(symbol2, expression2).build(); } @@ -84,7 +87,7 @@ public static Assignments rewrite(Assignments assignments, ExpressionRewrite public static Assignments rewrite(Assignments assignments, Function rewrite) { return assignments.entrySet().stream() - .map(entry -> Maps.immutableEntry(entry.getKey(), rewrite.apply(entry.getValue()))) + .map(entry -> Maps.immutableEntry(entry.getKey(), castToRowExpression(rewrite.apply(castToExpression(entry.getValue()))))) .collect(toAssignments()); } @@ -102,12 +105,12 @@ public static Assignments filter(Assignments assignments, Predicate pred public static boolean isIdentity(Assignments assignments, Symbol output) { - Expression expression = assignments.get(output); + Expression expression = castToExpression(assignments.get(output)); return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); } - private static Collector, Builder, Assignments> toAssignments() + private static Collector, Builder, Assignments> toAssignments() { return Collector.of( AssignmentsUtils::builder, @@ -122,25 +125,25 @@ private static Collector, Builder, Assignments> to // Originally, the following class is also static public static class Builder { - private final Map assignments = new LinkedHashMap<>(); + private final Map assignments = new LinkedHashMap<>(); public Builder putAll(Assignments assignments) { return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Map.Entry assignment : assignments.entrySet()) { + for (Map.Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; } - public Builder put(Symbol symbol, Expression expression) + public Builder put(Symbol symbol, RowExpression expression) { if (assignments.containsKey(symbol)) { - Expression assignment = assignments.get(symbol); + RowExpression assignment = assignments.get(symbol); checkState( assignment.equals(expression), "Symbol %s already has assignment %s, while adding %s", @@ -152,7 +155,7 @@ public Builder put(Symbol symbol, Expression expression) return this; } - public Builder put(Map.Entry assignment) + public Builder put(Map.Entry assignment) { put(assignment.getKey(), assignment.getValue()); return this; @@ -168,7 +171,7 @@ public Builder putIdentities(Iterable symbols) public Builder putIdentity(Symbol symbol) { - put(symbol, symbol.toSymbolReference()); + put(symbol, castToRowExpression(symbol.toSymbolReference())); return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index 96b955ffefee7..2a6e839df2fc7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -31,6 +31,7 @@ import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.UnionNode; import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; @@ -59,6 +60,7 @@ import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.util.Objects.requireNonNull; /** @@ -159,7 +161,7 @@ public Expression visitProject(ProjectNode node, Void context) Expression underlyingPredicate = node.getSource().accept(this, context); - List projectionEqualities = node.getAssignments().entrySet().stream() + List projectionEqualities = transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression).entrySet().stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java index 62310b0d89a93..ff91474279c13 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java @@ -106,7 +106,7 @@ public Void visitFilter(FilterNode node, ImmutableList.Builder co @Override public Void visitProject(ProjectNode node, ImmutableList.Builder context) { - context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())); + context.addAll(node.getAssignments().getExpressions().stream().collect(toImmutableList())); return super.visitProject(node, context); } @@ -130,7 +130,6 @@ public Void visitApply(ApplyNode node, ImmutableList.Builder cont context.addAll(node.getSubqueryAssignments() .getExpressions() .stream() - .map(OriginalExpressionUtils::castToRowExpression) .collect(toImmutableList())); return super.visitApply(node, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index c0384e1374265..969ed66aec585 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -109,6 +109,7 @@ import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.FunctionType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spiller.PartitioningSpillerFactory; @@ -273,7 +274,6 @@ import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; import static io.airlift.units.DataSize.Unit.BYTE; @@ -1143,7 +1143,12 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext RowExpression filterExpression = node.getPredicate(); List outputSymbols = node.getOutputSymbols(); - return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), AssignmentsUtils.identity(outputSymbols), outputSymbols); + AssignmentsUtils.Builder identities = AssignmentsUtils.builder(); + for (Symbol symbol : outputSymbols) { + Type type = requireNonNull(context.getTypes().get(symbol), format("No type for symbol %s", symbol)); + identities.put(symbol, new VariableReferenceExpression(symbol.getName(), type)); + } + return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identities.build(), outputSymbols); } @Override @@ -1218,30 +1223,15 @@ private PhysicalOperation visitScanFilterAndProject( Map outputMappings = outputMappingsBuilder.build(); // compiler uses inputs instead of symbols, so rewrite the expressions first - - List projections = new ArrayList<>(); - for (Symbol symbol : outputSymbols) { - projections.add(assignments.get(symbol)); - } - - Map, Type> expressionTypes = getExpressionTypes( - context.getSession(), - metadata, - sqlParser, - context.getTypes(), - concat(assignments.getExpressions()), - emptyList(), - NOOP, - false); - - List translatedProjections = projections.stream() - .map(expression -> toRowExpression(expression, expressionTypes, sourceLayout)) + List projections = outputSymbols.stream() + .map(assignments::get) + .map(expression -> bindChannels(expression, sourceLayout)) .collect(toImmutableList()); try { if (columns != null) { - Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, translatedProjections, sourceNode.getId()); - Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); + Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, projections, sourceNode.getId()); + Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId)); SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory( context.getNextOperatorId(), @@ -1251,20 +1241,20 @@ private PhysicalOperation visitScanFilterAndProject( cursorProcessor, pageProcessor, columns, - getTypes(projections, expressionTypes), + projections.stream().map(RowExpression::getType).collect(toImmutableList()), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION); } else { - Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); + Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId)); OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, pageProcessor, - getTypes(projections, expressionTypes), + projections.stream().map(RowExpression::getType).collect(toImmutableList()), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -2796,14 +2786,6 @@ private OperatorFactory createHashAggregationOperatorFactory( } } - private static List getTypes(List expressions, Map, Type> expressionTypes) - { - return expressions.stream() - .map(NodeRef::of) - .map(expressionTypes::get) - .collect(toImmutableList()); - } - private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata) { WriterTarget target = node.getTarget(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 505a3a7d42a90..9b3645017c683 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -96,6 +96,7 @@ import static com.facebook.presto.sql.planner.plan.TableWriterNode.WriterTarget; import static com.facebook.presto.sql.planner.sanity.PlanSanityChecker.DISTRIBUTED_PLAN_SANITY_CHECKER; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -354,7 +355,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) int index = insert.getColumns().indexOf(columns.get(column.getName())); if (index < 0) { Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString()); - assignments.put(output, cast); + assignments.put(output, castToRowExpression(cast)); } else { Symbol input = plan.getSymbol(index); @@ -362,11 +363,11 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) Type queryType = symbolAllocator.getTypes().get(input); if (queryType.equals(tableType) || metadata.getTypeManager().isTypeOnlyCoercion(queryType, tableType)) { - assignments.put(output, input.toSymbolReference()); + assignments.put(output, castToRowExpression(input.toSymbolReference())); } else { Expression cast = new Cast(input.toSymbolReference(), tableType.getTypeSignature().toString()); - assignments.put(output, cast); + assignments.put(output, castToRowExpression(cast)); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index 48b9a7933f6f3..977621b25262b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static java.util.Objects.requireNonNull; class PlanBuilder @@ -96,13 +97,13 @@ public PlanBuilder appendProjections(Iterable expressions, SymbolAll // add an identity projection for underlying plan for (Symbol symbol : getRoot().getOutputSymbols()) { - projections.put(symbol, symbol.toSymbolReference()); + projections.put(symbol, castToRowExpression(symbol.toSymbolReference())); } ImmutableMap.Builder newTranslations = ImmutableMap.builder(); for (Expression expression : expressions) { Symbol symbol = symbolAllocator.newSymbol(expression, getAnalysis().getTypeWithCoercions(expression)); - projections.put(symbol, translations.rewrite(expression)); + projections.put(symbol, castToRowExpression(translations.rewrite(expression))); newTranslations.put(symbol, expression); } // Now append the new translations into the TranslationMap diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index bd099bba56411..07ad7cd8a7322 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -41,6 +41,8 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; @@ -67,7 +69,6 @@ import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.sanity.PlanSanityChecker; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -102,6 +103,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan; +import static com.facebook.presto.sql.relational.Expressions.constant; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.in; @@ -493,13 +495,13 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite private PartitioningSymbolAssignments assignPartitioningSymbols(Partitioning partitioning) { ImmutableList.Builder symbols = ImmutableList.builder(); - ImmutableMap.Builder constants = ImmutableMap.builder(); + ImmutableMap.Builder constants = ImmutableMap.builder(); for (ArgumentBinding argumentBinding : partitioning.getArguments()) { Symbol symbol; if (argumentBinding.isConstant()) { NullableValue constant = argumentBinding.getConstant(); - Expression expression = literalEncoder.toExpression(constant.getValue(), constant.getType()); - symbol = symbolAllocator.newSymbol(expression, constant.getType()); + RowExpression expression = constant(constant.getValue(), constant.getType()); + symbol = symbolAllocator.newSymbol("constant_partition", constant.getType()); constants.put(symbol, expression); } else { @@ -569,7 +571,7 @@ private TableFinishNode createTemporaryTableWrite( List outputs, List> inputs, List sources, - Map constantExpressions, + Map constantExpressions, PartitioningMetadata partitioningMetadata) { if (!constantExpressions.isEmpty()) { @@ -593,7 +595,7 @@ private TableFinishNode createTemporaryTableWrite( sources = sources.stream() .map(source -> { AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); - assignments.putIdentities(source.getOutputSymbols()); + source.getOutputSymbols().forEach(symbol -> assignments.put(symbol, new VariableReferenceExpression(symbol.getName(), symbolAllocator.getTypes().get(symbol)))); constantSymbols.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol))); return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); }) @@ -1082,9 +1084,9 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) private static class PartitioningSymbolAssignments { private final List symbols; - private final Map constants; + private final Map constants; - private PartitioningSymbolAssignments(List symbols, Map constants) + private PartitioningSymbolAssignments(List symbols, Map constants) { this.symbols = ImmutableList.copyOf(requireNonNull(symbols, "symbols is null")); this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null")); @@ -1098,7 +1100,7 @@ public List getSymbols() return symbols; } - public Map getConstants() + public Map getConstants() { return constants; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index ad457368d571f..57827c8f45c8c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.block.SortOrder; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Field; @@ -336,13 +337,13 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression for (Expression expression : expressions) { if (expression instanceof SymbolReference) { Symbol symbol = Symbol.from(expression); - projections.put(symbol, expression); + projections.put(symbol, castToRowExpression(expression)); outputTranslations.put(expression, symbol); continue; } Symbol symbol = symbolAllocator.newSymbol(expression, analysis.getTypeWithCoercions(expression)); - projections.put(symbol, subPlan.rewrite(expression)); + projections.put(symbol, castToRowExpression(subPlan.rewrite(expression))); outputTranslations.put(expression, symbol); } @@ -353,9 +354,9 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression analysis.getParameters()); } - private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { - ImmutableMap.Builder projections = ImmutableMap.builder(); + ImmutableMap.Builder projections = ImmutableMap.builder(); for (Expression expression : expressions) { Type type = analysis.getType(expression); @@ -369,7 +370,7 @@ private Map coerce(Iterable expression false, metadata.getTypeManager().isTypeOnlyCoercion(type, coercion)); } - projections.put(symbol, rewritten); + projections.put(symbol, castToRowExpression(rewritten)); translations.put(expression, symbol); } @@ -388,13 +389,13 @@ private PlanBuilder explicitCoercionFields(PlanBuilder subPlan, Iterable assignments.put(key, value.toSymbolReference())); + groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(value.toSymbolReference()))); ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build()); subPlan = new PlanBuilder(groupingTranslations, project, analysis.getParameters()); @@ -681,7 +682,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica false, metadata.getTypeManager().isTypeOnlyCoercion(analysis.getType(groupingOperation), coercion)); } - projections.put(symbol, rewritten); + projections.put(symbol, castToRowExpression(rewritten)); newTranslations.put(groupingOperation, symbol); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 4f0457d3a275b..10cc44e8ceb01 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -181,7 +181,7 @@ protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context) Field field = subPlan.getDescriptor().getFieldByIndex(i); if (!field.isHidden()) { Symbol aliasedColumn = symbolAllocator.newSymbol(field); - assignments.put(aliasedColumn, subPlan.getFieldMappings().get(i).toSymbolReference()); + assignments.put(aliasedColumn, castToRowExpression(subPlan.getFieldMappings().get(i).toSymbolReference())); newMappings.add(aliasedColumn); } } @@ -440,21 +440,21 @@ If casts are redundant (due to column type and common type being equal), // compute the coercion for the field on the left to the common supertype of left & right Symbol leftOutput = symbolAllocator.newSymbol(identifier, type); int leftField = joinAnalysis.getLeftJoinFields().get(i); - leftCoercions.put(leftOutput, new Cast( + leftCoercions.put(leftOutput, castToRowExpression(new Cast( left.getSymbol(leftField).toSymbolReference(), type.getTypeSignature().toString(), false, - metadata.getTypeManager().isTypeOnlyCoercion(left.getDescriptor().getFieldByIndex(leftField).getType(), type))); + metadata.getTypeManager().isTypeOnlyCoercion(left.getDescriptor().getFieldByIndex(leftField).getType(), type)))); leftJoinColumns.put(identifier, leftOutput); // compute the coercion for the field on the right to the common supertype of left & right Symbol rightOutput = symbolAllocator.newSymbol(identifier, type); int rightField = joinAnalysis.getRightJoinFields().get(i); - rightCoercions.put(rightOutput, new Cast( + rightCoercions.put(rightOutput, castToRowExpression(new Cast( right.getSymbol(rightField).toSymbolReference(), type.getTypeSignature().toString(), false, - metadata.getTypeManager().isTypeOnlyCoercion(right.getDescriptor().getFieldByIndex(rightField).getType(), type))); + metadata.getTypeManager().isTypeOnlyCoercion(right.getDescriptor().getFieldByIndex(rightField).getType(), type)))); rightJoinColumns.put(identifier, rightOutput); clauses.add(new JoinNode.EquiJoinClause(leftOutput, rightOutput)); @@ -486,21 +486,21 @@ If casts are redundant (due to column type and common type being equal), for (Identifier column : joinColumns) { Symbol output = symbolAllocator.newSymbol(column, analysis.getType(column)); outputs.add(output); - assignments.put(output, new CoalesceExpression( + assignments.put(output, castToRowExpression(new CoalesceExpression( leftJoinColumns.get(column).toSymbolReference(), - rightJoinColumns.get(column).toSymbolReference())); + rightJoinColumns.get(column).toSymbolReference()))); } for (int field : joinAnalysis.getOtherLeftFields()) { Symbol symbol = left.getFieldMappings().get(field); outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + assignments.put(symbol, castToRowExpression(symbol.toSymbolReference())); } for (int field : joinAnalysis.getOtherRightFields()) { Symbol symbol = right.getFieldMappings().get(field); outputs.add(symbol); - assignments.put(symbol, symbol.toSymbolReference()); + assignments.put(symbol, castToRowExpression(symbol.toSymbolReference())); } return new RelationPlan( @@ -730,13 +730,13 @@ private RelationPlan addCoercions(RelationPlan plan, Type[] targetColumnTypes) if (!outputType.equals(inputType)) { Expression cast = new Cast(inputSymbol.toSymbolReference(), outputType.getTypeSignature().toString()); Symbol outputSymbol = symbolAllocator.newSymbol(cast, outputType); - assignments.put(outputSymbol, cast); + assignments.put(outputSymbol, castToRowExpression(cast)); newSymbols.add(outputSymbol); } else { SymbolReference symbolReference = inputSymbol.toSymbolReference(); Symbol outputSymbol = symbolAllocator.newSymbol(symbolReference, outputType); - assignments.put(outputSymbol, symbolReference); + assignments.put(outputSymbol, castToRowExpression(symbolReference)); newSymbols.add(outputSymbol); } Field oldField = oldDescriptor.getFieldByIndex(i); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index d9e089b90adf4..822a7d085385a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -199,7 +199,7 @@ private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate subPlan.getTranslations().put(inPredicate, inPredicateSubquerySymbol); - return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), AssignmentsUtils.of(inPredicateSubquerySymbol, inPredicateSubqueryExpression), correlationAllowed); + return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), AssignmentsUtils.of(inPredicateSubquerySymbol, castToRowExpression(inPredicateSubqueryExpression)), correlationAllowed); } private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set scalarSubqueries, boolean correlationAllowed) @@ -298,7 +298,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred subPlan, existsPredicate.getSubquery(), subqueryNode, - AssignmentsUtils.of(exists, rewrittenExistsPredicate), + AssignmentsUtils.of(exists, castToRowExpression(rewrittenExistsPredicate)), correlationAllowed); } @@ -396,7 +396,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa subPlan, quantifiedComparison.getSubquery(), subqueryPlan.getRoot(), - AssignmentsUtils.of(coercedQuantifiedComparisonSymbol, coercedQuantifiedComparison), + AssignmentsUtils.of(coercedQuantifiedComparisonSymbol, castToRowExpression(coercedQuantifiedComparison)), correlationAllowed); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index e3f112d56ee58..d3a9a251136e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -48,6 +49,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -203,7 +205,7 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra result = new ProjectNode( idAllocator.getNextId(), result, - AssignmentsUtils.copyOf(graph.getAssignments().get())); + AssignmentsUtils.copyOf(transformValues(graph.getAssignments().get(), OriginalExpressionUtils::castToRowExpression))); } // If needed, introduce a projection to constrain the outputs to what was originally expected diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index c0d8c684e000a..949282b0edab3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -93,6 +93,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.filter; import static com.facebook.presto.sql.planner.plan.Patterns.join; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.LESS_THAN; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL; import static com.facebook.presto.util.SpatialJoinUtils.extractSupportedSpatialComparisons; @@ -589,7 +590,7 @@ private static PlanNode addProjection(Context context, PlanNode node, Symbol sym projections.putIdentity(outputSymbol); } - projections.put(symbol, expression); + projections.put(symbol, castToRowExpression(expression)); return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build()); } @@ -607,7 +608,7 @@ private static PlanNode addPartitioningNodes(Context context, PlanNode node, Sym FunctionCall partitioningFunction = new FunctionCall(QualifiedName.of("spatial_partitions"), partitioningArguments.build()); Symbol partitionsSymbol = context.getSymbolAllocator().newSymbol(partitioningFunction, new ArrayType(INTEGER)); - projections.put(partitionsSymbol, partitioningFunction); + projections.put(partitionsSymbol, castToRowExpression(partitioningFunction)); return new UnnestNode( context.getIdAllocator().getNextId(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index 4c80553b961b6..235e37c97136e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -17,6 +17,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.PropertyPattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; @@ -26,7 +27,7 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -49,6 +50,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.window; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.stream.Collectors.toList; public class GatherAndMergeWindows { @@ -139,7 +141,7 @@ protected static Optional pullWindowNodeAboveProjects( // The only kind of use of the output of the target that we can safely ignore is a simple identity propagation. // The target node, when hoisted above the projections, will provide the symbols directly. - Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( + Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( project.getAssignments().getMap(), output -> !(AssignmentsUtils.isIdentity(project.getAssignments(), output) && targetOutputs.contains(output))); @@ -153,7 +155,7 @@ protected static Optional pullWindowNodeAboveProjects( .putIdentities(targetInputs) .build(); - if (!newTargetChildOutputs.containsAll(SymbolsExtractor.extractUnique(newAssignments.getExpressions()))) { + if (!newTargetChildOutputs.containsAll(SymbolsExtractor.extractUnique(newAssignments.getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toList())))) { // Projection uses an output of the target -- can't move the target above this projection. return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 1f85ddba6b2df..e5480afd4be46 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -97,7 +97,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont Expression filter = call.getFilter().get(); Symbol symbol = context.getSymbolAllocator().newSymbol(filter, BOOLEAN); verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern"); - newAssignments.put(symbol, filter); + newAssignments.put(symbol, castToRowExpression(filter)); mask = Optional.of(symbol); maskSymbols.add(symbol.toSymbolReference()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index f11491ca65ab7..2be5fb2d27aa1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -16,12 +16,14 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.TryExpression; @@ -38,6 +40,8 @@ import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static java.util.stream.Collectors.toSet; /** @@ -72,25 +76,27 @@ public Result apply(ProjectNode parent, Captures captures, Context context) // inline the expressions Assignments assignments = AssignmentsUtils.filter(child.getAssignments(), targets::contains); - Map parentAssignments = parent.getAssignments() + Map parentAssignments = parent.getAssignments() .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, - entry -> inlineReferences(entry.getValue(), assignments))); + entry -> castToRowExpression(inlineReferences(castToExpression(entry.getValue()), assignments)))); // Synthesize identity assignments for the inputs of expressions that were inlined // to place in the child projection. // If all assignments end up becoming identity assignments, they'll get pruned by // other rules Set inputs = child.getAssignments() - .entrySet().stream() + .entrySet() + .stream() .filter(entry -> targets.contains(entry.getKey())) .map(Map.Entry::getValue) + .map(OriginalExpressionUtils::castToExpression) .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) .collect(toSet()); AssignmentsUtils.Builder childAssignments = AssignmentsUtils.builder(); - for (Map.Entry assignment : child.getAssignments().entrySet()) { + for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { childAssignments.put(assignment); } @@ -112,12 +118,10 @@ public Result apply(ProjectNode parent, Captures captures, Context context) private Expression inlineReferences(Expression expression, Assignments assignments) { Function mapping = symbol -> { - Expression result = assignments.get(symbol); - if (result != null) { - return result; + if (assignments.get(symbol) == null) { + return symbol.toSymbolReference(); } - - return symbol.toSymbolReference(); + return castToExpression(assignments.get(symbol)); }; return inlineSymbols(mapping, expression); @@ -136,21 +140,23 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN Set childOutputSet = ImmutableSet.copyOf(child.getOutputSymbols()); Map dependencies = parent.getAssignments() - .getExpressions().stream() + .getExpressions() + .stream() + .map(OriginalExpressionUtils::castToExpression) .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream()) .filter(childOutputSet::contains) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); // find references to simple constants Set constants = dependencies.keySet().stream() - .filter(input -> child.getAssignments().get(input) instanceof Literal) + .filter(input -> castToExpression(child.getAssignments().get(input)) instanceof Literal) .collect(toSet()); // exclude any complex inputs to TRY expressions. Inlining them would potentially // change the semantics of those expressions Set tryArguments = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> extractTryArguments(expression).stream()) + .flatMap(expression -> extractTryArguments(castToExpression(expression)).stream()) .collect(toSet()); Set singletons = dependencies.entrySet().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java index 09a659b9a290a..2df0eb589c710 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import java.util.Optional; @@ -30,6 +31,7 @@ import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static java.util.stream.Collectors.toList; /** * @param The node type to look for under the ProjectNode @@ -60,7 +62,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { N targetNode = captures.get(targetCapture); - return pruneInputs(targetNode.getOutputSymbols(), parent.getAssignments().getExpressions()) + return pruneInputs(targetNode.getOutputSymbols(), parent.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toList())) .flatMap(prunedOutputs -> this.pushDownProjectOff(context.getIdAllocator(), targetNode, prunedOutputs)) .map(newChild -> parent.replaceChildren(ImmutableList.of(newChild))) .map(Result::ofPlanNode) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index ba6f9d7475137..aba709978cdbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -255,10 +255,10 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati AssignmentsUtils.Builder assignmentsBuilder = AssignmentsUtils.builder(); for (Symbol symbol : outerJoin.getOutputSymbols()) { if (aggregationNode.getAggregations().containsKey(symbol)) { - assignmentsBuilder.put(symbol, new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference())); + assignmentsBuilder.put(symbol, castToRowExpression(new CoalesceExpression(symbol.toSymbolReference(), sourceAggregationToOverNullMapping.get(symbol).toSymbolReference()))); } else { - assignmentsBuilder.put(symbol, symbol.toSymbolReference()); + assignmentsBuilder.put(symbol, castToRowExpression(symbol.toSymbolReference())); } } return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index f51954689ef3f..67d8db831ebd9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -52,6 +52,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -167,7 +168,7 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, for (Symbol output : aggregation.getOutputSymbols()) { Symbol input = symbolMapper.map(output); - assignments.put(output, input.toSymbolReference()); + assignments.put(output, castToRowExpression(input.toSymbolReference())); } partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 58ea2e8857e7f..d4e421ba10c66 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -24,6 +25,7 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; @@ -40,6 +42,8 @@ import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; /** * Transforms: @@ -97,13 +101,13 @@ public Result apply(ProjectNode project, Captures captures, Context context) .map(outputToInputMap::get) .forEach(nameReference -> { Symbol symbol = Symbol.from(nameReference); - projections.put(symbol, nameReference); + projections.put(symbol, castToRowExpression(nameReference)); inputs.add(symbol); }); if (exchange.getPartitioningScheme().getHashColumn().isPresent()) { // Need to retain the hash symbol for the exchange - projections.put(exchange.getPartitioningScheme().getHashColumn().get(), exchange.getPartitioningScheme().getHashColumn().get().toSymbolReference()); + projections.put(exchange.getPartitioningScheme().getHashColumn().get(), castToRowExpression(exchange.getPartitioningScheme().getHashColumn().get().toSymbolReference())); inputs.add(exchange.getPartitioningScheme().getHashColumn().get()); } @@ -115,16 +119,16 @@ public Result apply(ProjectNode project, Captures captures, Context context) .map(outputToInputMap::get) .forEach(nameReference -> { Symbol symbol = Symbol.from(nameReference); - projections.put(symbol, nameReference); + projections.put(symbol, castToRowExpression(nameReference)); inputs.add(symbol); }); } - for (Map.Entry projection : project.getAssignments().entrySet()) { - Expression translatedExpression = inlineSymbols(outputToInputMap, projection.getValue()); + for (Map.Entry projection : project.getAssignments().entrySet()) { + Expression translatedExpression = inlineSymbols(outputToInputMap, castToExpression(projection.getValue())); Type type = context.getSymbolAllocator().getTypes().get(projection.getKey()); Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); - projections.put(symbol, translatedExpression); + projections.put(symbol, castToRowExpression(translatedExpression)); inputs.add(symbol); } newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build())); @@ -140,7 +144,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) .filter(symbol -> !partitioningColumns.contains(symbol)) .forEach(outputBuilder::add); } - for (Map.Entry projection : project.getAssignments().entrySet()) { + for (Map.Entry projection : project.getAssignments().entrySet()) { outputBuilder.add(projection.getKey()); } @@ -167,7 +171,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) private static boolean isSymbolToSymbolProjection(ProjectNode project) { - return project.getAssignments().getExpressions().stream().allMatch(e -> e instanceof SymbolReference); + return project.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).allMatch(e -> e instanceof SymbolReference); } private static Map extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index d6a9b6b079934..7f1dcb942ba32 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Symbol; @@ -38,6 +39,8 @@ import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.union; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; public class PushProjectionThroughUnion implements Rule @@ -75,11 +78,11 @@ public Result apply(ProjectNode parent, Captures captures, Context context) Map projectSymbolMapping = new HashMap<>(); // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode - for (Map.Entry entry : parent.getAssignments().entrySet()) { - Expression translatedExpression = inlineSymbols(outputToInput, entry.getValue()); + for (Map.Entry entry : parent.getAssignments().entrySet()) { + Expression translatedExpression = inlineSymbols(outputToInput, castToExpression(entry.getValue())); Type type = context.getSymbolAllocator().getTypes().get(entry.getKey()); Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type); - assignments.put(symbol, translatedExpression); + assignments.put(symbol, castToRowExpression(translatedExpression)); projectSymbolMapping.put(entry.getKey(), symbol); } outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index 8aaf9c855c9be..d282ce9872f39 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -16,6 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.planner.AssignmentsUtils; @@ -38,6 +39,7 @@ import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -89,7 +91,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) { ImmutableMap.Builder aggregations = ImmutableMap.builder(); Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", INTEGER); - ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); + ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); for (Map.Entry entry : node.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); FunctionCall call = aggregation.getCall(); @@ -99,10 +101,10 @@ public Result apply(AggregationNode node, Captures captures, Context context) Expression geometry = getOnlyElement(call.getArguments()); Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", geometryType); if (geometry instanceof FunctionCall && ((FunctionCall) geometry).getName().toString().equalsIgnoreCase("ST_Envelope")) { - envelopeAssignments.put(envelopeSymbol, geometry); + envelopeAssignments.put(envelopeSymbol, castToRowExpression(geometry)); } else { - envelopeAssignments.put(envelopeSymbol, new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(geometry))); + envelopeAssignments.put(envelopeSymbol, castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(geometry)))); } aggregations.put(entry.getKey(), new Aggregation( @@ -123,7 +125,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getSource(), AssignmentsUtils.builder() .putIdentities(node.getSource().getOutputSymbols()) - .put(partitionCountSymbol, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) + .put(partitionCountSymbol, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession()))))) .putAll(envelopeAssignments.build()) .build()), aggregations.build(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 8564296c4c372..15c38155b7e43 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -38,6 +38,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static java.util.Objects.requireNonNull; public class SimplifyCountOverConstant @@ -106,7 +107,7 @@ private static boolean isCountOverConstant(AggregationNode.Aggregation aggregati Expression argument = aggregation.getCall().getArguments().get(0); if (argument instanceof SymbolReference) { - argument = inputs.get(Symbol.from(argument)); + argument = castToExpression(inputs.get(Symbol.from(argument))); } return argument instanceof Literal && !(argument instanceof NullLiteral); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 9073f707f611f..bbb8ec0666aa2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -67,6 +67,7 @@ import static com.facebook.presto.sql.planner.plan.Patterns.Apply.correlation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -118,7 +119,7 @@ public Result apply(ApplyNode apply, Captures captures, Context context) if (subqueryAssignments.size() != 1) { return Result.empty(); } - Expression assignmentExpression = getOnlyElement(subqueryAssignments.getExpressions()); + Expression assignmentExpression = castToExpression(getOnlyElement(subqueryAssignments.getExpressions())); if (!(assignmentExpression instanceof InPredicate)) { return Result.empty(); } @@ -177,7 +178,7 @@ private PlanNode buildInPredicateEquivalent( decorrelatedBuildSource, AssignmentsUtils.builder() .putIdentities(decorrelatedBuildSource.getOutputSymbols()) - .put(buildSideKnownNonNull, bigint(0)) + .put(buildSideKnownNonNull, castToRowExpression(bigint(0))) .build()); Symbol probeSideSymbol = Symbol.from(inPredicate.getValue()); @@ -227,7 +228,7 @@ private PlanNode buildInPredicateEquivalent( aggregation, AssignmentsUtils.builder() .putIdentities(apply.getInput().getOutputSymbols()) - .put(inPredicateOutputSymbol, inPredicateEquivalent) + .put(inPredicateOutputSymbol, castToRowExpression(inPredicateEquivalent)) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 017de469904b0..fd76a719fd701 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -47,6 +47,8 @@ import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; import static com.google.common.base.Preconditions.checkState; @@ -101,7 +103,7 @@ public Result apply(ApplyNode parent, Captures captures, Context context) return Result.empty(); } - Expression expression = getOnlyElement(parent.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(parent.getSubqueryAssignments().getExpressions())); if (!(expression instanceof ExistsPredicate)) { return Result.empty(); } @@ -121,7 +123,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); assignments.putIdentities(applyNode.getInput().getOutputSymbols()); - assignments.put(exists, new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL))); + assignments.put(exists, castToRowExpression(new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL)))); PlanNode subquery = new ProjectNode( context.getIdAllocator().getNextId(), @@ -130,7 +132,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C applyNode.getSubquery(), 1L, false), - AssignmentsUtils.of(subqueryTrue, TRUE_LITERAL)); + AssignmentsUtils.of(subqueryTrue, castToRowExpression(TRUE_LITERAL))); PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup()); if (!decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isPresent()) { @@ -170,7 +172,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()), - AssignmentsUtils.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString())))), + AssignmentsUtils.of(exists, castToRowExpression(new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), BIGINT.toString()))))), parent.getCorrelation(), INNER, parent.getOriginSubqueryError()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index d844134e172da..b5c9956523056 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -27,6 +27,7 @@ import static com.facebook.presto.matching.Pattern.empty; import static com.facebook.presto.sql.planner.plan.Patterns.Apply.correlation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.google.common.collect.Iterables.getOnlyElement; /** @@ -71,7 +72,7 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) return Result.empty(); } - Expression expression = getOnlyElement(applyNode.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(applyNode.getSubqueryAssignments().getExpressions())); if (!(expression instanceof InPredicate)) { return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 44480f48f3e5d..59567953d7844 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; @@ -1311,9 +1312,9 @@ private Scope selectExchangeScopeForPartitionedRemoteExchange(PlanNode exchangeS private static Map computeIdentityTranslations(Assignments assignments) { Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); + for (Map.Entry assignment : assignments.getMap().entrySet()) { + if (castToExpression(assignment.getValue()) instanceof SymbolReference) { + outputToInput.put(assignment.getKey(), Symbol.from(castToExpression(assignment.getValue()))); } } return outputToInput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 1440f270ee333..44ecbeb885cea 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -17,6 +17,7 @@ import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.planner.AssignmentsUtils; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; @@ -79,6 +80,8 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -625,7 +628,7 @@ public PlanWithProperties visitProject(ProjectNode node, HashComputationSet pare else { hashExpression = hashSymbol.toSymbolReference(); } - newAssignments.put(hashSymbol, hashExpression); + newAssignments.put(hashSymbol, castToRowExpression(hashExpression)); allHashSymbols.put(hashComputation, hashSymbol); } @@ -722,7 +725,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo for (Symbol symbol : planWithProperties.getNode().getOutputSymbols()) { HashComputation partitionSymbols = resultHashSymbols.get(symbol); if (partitionSymbols == null || requiredHashes.getHashes().contains(partitionSymbols)) { - assignments.put(symbol, symbol.toSymbolReference()); + assignments.put(symbol, castToRowExpression(symbol.toSymbolReference())); if (partitionSymbols != null) { outputHashSymbols.put(partitionSymbols, symbol); @@ -735,7 +738,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo if (!planWithProperties.getHashSymbols().containsKey(hashComputation)) { Expression hashExpression = hashComputation.getHashExpression(); Symbol hashSymbol = symbolAllocator.newHashSymbol(); - assignments.put(hashSymbol, hashExpression); + assignments.put(hashSymbol, castToRowExpression(hashExpression)); outputHashSymbols.put(hashComputation, hashSymbol); } } @@ -969,12 +972,12 @@ public Symbol getRequiredHashSymbol(HashComputation hash) } } - private static Map computeIdentityTranslations(Map assignments) + private static Map computeIdentityTranslations(Map assignments) { Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), Symbol.from(assignment.getValue())); + for (Map.Entry assignment : assignments.entrySet()) { + if (castToExpression(assignment.getValue()) instanceof SymbolReference) { + outputToInput.put(assignment.getKey(), Symbol.from(castToExpression(assignment.getValue()))); } } return outputToInput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 26e58f71c9bcc..9075ca606b80c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -217,13 +217,13 @@ private PlanNode appendMarkers(PlanNode source, int markerIndex, List ma // add existing intersect symbols to projection for (Map.Entry entry : projections.entrySet()) { Symbol symbol = symbolAllocator.newSymbol(entry.getKey().getName(), symbolAllocator.getTypes().get(entry.getKey())); - assignments.put(symbol, entry.getValue()); + assignments.put(symbol, castToRowExpression(entry.getValue())); } // add extra marker fields to the projection for (int i = 0; i < markers.size(); ++i) { Expression expression = (i == markerIndex) ? TRUE_LITERAL : new Cast(new NullLiteral(), StandardTypes.BOOLEAN); - assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BOOLEAN), expression); + assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BOOLEAN), castToRowExpression(expression)); } return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index b5fee93f6f588..2783325297099 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -327,8 +327,8 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) // Rewrite the lookup symbols in terms of only the pre-projected symbols that have direct translations Set newLookupSymbols = context.get().getLookupSymbols().stream() .map(node.getAssignments()::get) - .filter(SymbolReference.class::isInstance) - .map(Symbol::from) + .filter(value -> castToExpression(value) instanceof SymbolReference) + .map(value -> Symbol.from(castToExpression(value))) .collect(toImmutableSet()); if (newLookupSymbols.isEmpty()) { @@ -486,7 +486,9 @@ protected Map visitPlan(PlanNode node, Set lookupSymbols public Map visitProject(ProjectNode node, Set lookupSymbols) { // Map from output Symbols to source Symbols - Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); + Map directSymbolTranslationOutputMap = Maps.transformValues( + Maps.filterValues(node.getAssignments().getMap(), value -> (castToExpression(value) instanceof SymbolReference)), + value -> Symbol.from(castToExpression(value))); Map outputToSourceMap = lookupSymbols.stream() .filter(directSymbolTranslationOutputMap.keySet()::contains) .collect(toImmutableMap(identity(), directSymbolTranslationOutputMap::get)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index aade31a0c6e69..b27919ff3864f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -45,6 +45,7 @@ import com.facebook.presto.sql.planner.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -57,6 +58,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; /** * Converts cardinality-insensitive aggregations (max, min, "distinct") over partition keys @@ -201,7 +203,7 @@ private static Optional findTableScan(PlanNode source) else if (source instanceof ProjectNode) { // verify projections are deterministic ProjectNode project = (ProjectNode) source; - if (!Iterables.all(project.getAssignments().getExpressions(), DeterminismEvaluator::isDeterministic)) { + if (!Iterables.all(project.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toList()), DeterminismEvaluator::isDeterministic)) { return Optional.empty(); } source = project.getSource(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 6086e10615395..f661e9839160d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -56,6 +56,7 @@ import static com.facebook.presto.SystemSessionProperties.isOptimizeDistinctAggregationEnabled; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -204,7 +205,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) public PlanNode visitProject(ProjectNode node, RewriteContext context) { Set deterministicSymbols = node.getAssignments().entrySet().stream() - .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue())) + .filter(entry -> DeterminismEvaluator.isDeterministic(castToExpression(entry.getValue()))) .map(Map.Entry::getKey) .collect(Collectors.toSet()); @@ -259,7 +261,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() - .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) + .map(entry -> inlineSymbols(transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression), entry)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); @@ -293,7 +295,7 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); return dependencies.entrySet().stream() - .allMatch(entry -> entry.getValue() == 1 || node.getAssignments().get(entry.getKey()) instanceof Literal); + .allMatch(entry -> entry.getValue() == 1 || castToExpression(node.getAssignments().get(entry.getKey())) instanceof Literal); } @Override @@ -346,7 +348,7 @@ public PlanNode visitUnion(UnionNode node, RewriteContext context) boolean modified = false; ImmutableList.Builder builder = ImmutableList.builder(); for (int i = 0; i < node.getSources().size(); i++) { - Expression sourcePredicate = inlineSymbols(Maps.transformValues(node.sourceSymbolMap(i), Symbol::toSymbolReference), context.get()); + Expression sourcePredicate = inlineSymbols(transformValues(node.sourceSymbolMap(i), Symbol::toSymbolReference), context.get()); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); if (rewrittenSource != source) { @@ -454,14 +456,10 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Create identity projections for all existing symbols AssignmentsUtils.Builder leftProjections = AssignmentsUtils.builder(); - leftProjections.putAll(node.getLeft() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + leftProjections.putAll(identity(node.getLeft().getOutputSymbols())); AssignmentsUtils.Builder rightProjections = AssignmentsUtils.builder(); - rightProjections.putAll(node.getRight() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + rightProjections.putAll(identity(node.getRight().getOutputSymbols())); // Create new projections for the new join clauses List equiJoinClauses = new ArrayList<>(); @@ -476,12 +474,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) Symbol leftSymbol = symbolForExpression(leftExpression); if (!node.getLeft().getOutputSymbols().contains(leftSymbol)) { - leftProjections.put(leftSymbol, leftExpression); + leftProjections.put(leftSymbol, castToRowExpression(leftExpression)); } Symbol rightSymbol = symbolForExpression(rightExpression); if (!node.getRight().getOutputSymbols().contains(rightSymbol)) { - rightProjections.put(rightSymbol, rightExpression); + rightProjections.put(rightSymbol, castToRowExpression(rightExpression)); } equiJoinClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol)); @@ -537,7 +535,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } if (!node.getOutputSymbols().equals(output.getOutputSymbols())) { - output = new ProjectNode(idAllocator.getNextId(), output, AssignmentsUtils.identity(node.getOutputSymbols())); + output = new ProjectNode(idAllocator.getNextId(), output, identity(node.getOutputSymbols())); } return output; @@ -603,14 +601,10 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext key, Symbol::toSymbolReference))); + leftProjections.putAll(identity(node.getLeft().getOutputSymbols())); AssignmentsUtils.Builder rightProjections = AssignmentsUtils.builder(); - rightProjections.putAll(node.getRight() - .getOutputSymbols().stream() - .collect(Collectors.toMap(key -> key, Symbol::toSymbolReference))); + rightProjections.putAll(identity(node.getRight().getOutputSymbols())); leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 15fda64fa93cb..e74b0813bbc1f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -26,12 +26,15 @@ import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.predicate.NullableValue; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.NoOpSymbolResolver; import com.facebook.presto.sql.planner.OrderingScheme; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.ActualProperties.Global; @@ -624,28 +627,49 @@ public ActualProperties visitProject(ProjectNode node, List in // Extract additional constants Map constants = new HashMap<>(); - for (Map.Entry assignment : node.getAssignments().entrySet()) { - Expression expression = assignment.getValue(); - - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); - Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); - ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); - // TODO: - // We want to use a symbol resolver that looks up in the constants from the input subplan - // to take advantage of constant-folding for complex expressions - // However, that currently causes errors when those expressions operate on arrays or row types - // ("ROW comparison not supported for fields with null elements", etc) - Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); - - if (value instanceof SymbolReference) { - Symbol symbol = Symbol.from((SymbolReference) value); - NullableValue existingConstantValue = constants.get(symbol); - if (existingConstantValue != null) { + for (Map.Entry assignment : node.getAssignments().entrySet()) { + RowExpression expression = assignment.getValue(); + + if (isExpression(expression)) { + Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, castToExpression(expression), emptyList(), WarningCollector.NOOP); + Type type = requireNonNull(expressionTypes.get(NodeRef.of(castToExpression(expression)))); + ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(castToExpression(expression), metadata, session, expressionTypes); + // TODO: + // We want to use a symbol resolver that looks up in the constants from the input subplan + // to take advantage of constant-folding for complex expressions + // However, that currently causes errors when those expressions operate on arrays or row types + // ("ROW comparison not supported for fields with null elements", etc) + Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); + + if (value instanceof SymbolReference) { + Symbol symbol = Symbol.from((SymbolReference) value); + NullableValue existingConstantValue = constants.get(symbol); + if (existingConstantValue != null) { + constants.put(assignment.getKey(), new NullableValue(type, value)); + } + } + else if (!(value instanceof Expression)) { constants.put(assignment.getKey(), new NullableValue(type, value)); } } - else if (!(value instanceof Expression)) { - constants.put(assignment.getKey(), new NullableValue(type, value)); + else { + // TODO: + // We want to use a symbol resolver that looks up in the constants from the input subplan + // to take advantage of constant-folding for complex expressions + // However, that currently causes errors when those expressions operate on arrays or row types + // ("ROW comparison not supported for fields with null elements", etc) + Object value = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), true).optimize(); + + if (value instanceof VariableReferenceExpression) { + Symbol symbol = new Symbol(((VariableReferenceExpression) value).getName()); + NullableValue existingConstantValue = constants.get(symbol); + if (existingConstantValue != null) { + constants.put(assignment.getKey(), new NullableValue(((VariableReferenceExpression) value).getType(), value)); + } + } + else if (!(value instanceof RowExpression)) { + constants.put(assignment.getKey(), new NullableValue(expression.getType(), value)); + } } } constants.putAll(translatedProperties.getConstants()); @@ -776,12 +800,20 @@ private static Optional> translateToNonConstantSymbols( return Optional.of(ImmutableList.copyOf(builder.build())); } - private static Map computeIdentityTranslations(Map assignments) + private static Map computeIdentityTranslations(Map assignments) { Map inputToOutput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + inputToOutput.put(Symbol.from(castToExpression(expression)), assignment.getKey()); + } + } + else { + if (expression instanceof VariableReferenceExpression) { + inputToOutput.put(new Symbol(((VariableReferenceExpression) expression).getName()), assignment.getKey()); + } } } return inputToOutput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 5c9b82264726b..8dba68d8ad465 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -87,6 +87,8 @@ import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.concat; @@ -528,7 +530,12 @@ public PlanNode visitProject(ProjectNode node, RewriteContext> conte AssignmentsUtils.Builder builder = AssignmentsUtils.builder(); node.getAssignments().forEach((symbol, expression) -> { if (context.get().contains(symbol)) { - expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); + if (isExpression(expression)) { + expectedInputs.addAll(SymbolsExtractor.extractUnique(castToExpression(expression))); + } + else { + expectedInputs.addAll(SymbolsExtractor.extractUnique(expression)); + } builder.put(symbol, expression); } }); @@ -779,12 +786,12 @@ public PlanNode visitApply(ApplyNode node, RewriteContext> context) // extract symbols required subquery plan ImmutableSet.Builder subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder(); AssignmentsUtils.Builder subqueryAssignments = AssignmentsUtils.builder(); - for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { + for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { Symbol output = entry.getKey(); - Expression expression = entry.getValue(); + Expression expression = castToExpression(entry.getValue()); if (context.get().contains(output)) { subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression)); - subqueryAssignments.put(output, expression); + subqueryAssignments.put(output, castToRowExpression(expression)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index 44f281e091f09..c1dcb8441321a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -47,6 +47,7 @@ import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -82,7 +83,7 @@ public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, Aggreg Symbol nonNull = symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = AssignmentsUtils.builder() .putIdentities(source.get().getNode().getOutputSymbols()) - .put(nonNull, TRUE_LITERAL) + .put(nonNull, castToRowExpression(TRUE_LITERAL)) .build(); ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode( idAllocator.getNextId(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 81ec1b676e6b1..c29243efd07f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -18,6 +18,8 @@ import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.Symbol; @@ -57,7 +59,6 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; @@ -82,6 +83,8 @@ import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.FIXED; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.MULTIPLE; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.SINGLE; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -336,12 +339,20 @@ public StreamProperties visitProject(ProjectNode node, List in return properties.translate(column -> Optional.ofNullable(identities.get(column))); } - private static Map computeIdentityTranslations(Map assignments) + private static Map computeIdentityTranslations(Map assignments) { Map inputToOutput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - inputToOutput.put(Symbol.from(assignment.getValue()), assignment.getKey()); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + inputToOutput.put(Symbol.from(castToExpression(expression)), assignment.getKey()); + } + } + else { + if (expression instanceof VariableReferenceExpression) { + inputToOutput.put(new Symbol(((VariableReferenceExpression) expression).getName()), assignment.getKey()); + } } } return inputToOutput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index decebd029cfbc..529e12b85ca33 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -58,6 +58,8 @@ import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; @@ -117,7 +119,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) return context.defaultRewrite(node); } - Expression expression = getOnlyElement(node.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(node.getSubqueryAssignments().getExpressions())); if (!(expression instanceof QuantifiedComparisonExpression)) { return context.defaultRewrite(node); } @@ -181,7 +183,7 @@ countNonNullValue, new Aggregation( Symbol quantifiedComparisonSymbol = getOnlyElement(node.getSubqueryAssignments().getSymbols()); - return projectExpressions(lateralJoinNode, AssignmentsUtils.of(quantifiedComparisonSymbol, valueComparedToSubquery)); + return projectExpressions(lateralJoinNode, AssignmentsUtils.of(quantifiedComparisonSymbol, castToRowExpression(valueComparedToSubquery))); } public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java index 61b89d9b63d2e..8e2c506cdaaf4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java @@ -20,8 +20,13 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.AssignmentsUtils; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; @@ -32,11 +37,14 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.execution.warnings.WarningCollector.NOOP; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.planner.plan.Patterns.filter; +import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.values; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; @@ -59,7 +67,59 @@ public Set> rules() // TODO: finish all other PlanNodes that have Expression return ImmutableSet.of( new ValuesExpressionTranslation(), - new FilterExpressionTranslation()); + new FilterExpressionTranslation(), + new ProjectExpressionTranslation(), + new ApplyExpressionTranslation()); + } + + private final class ProjectExpressionTranslation + implements Rule + { + @Override + public Pattern getPattern() + { + return project(); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + Assignments assignments = projectNode.getAssignments(); + Optional rewrittenAssignments = translateAssignments(assignments, context); + + if (!rewrittenAssignments.isPresent()) { + return Result.empty(); + } + return Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), rewrittenAssignments.get())); + } + } + + private final class ApplyExpressionTranslation + implements Rule + { + @Override + public Pattern getPattern() + { + return applyNode(); + } + + @Override + public Result apply(ApplyNode applyNode, Captures captures, Context context) + { + Assignments assignments = applyNode.getSubqueryAssignments(); + Optional rewrittenAssignments = translateAssignments(assignments, context); + + if (!rewrittenAssignments.isPresent()) { + return Result.empty(); + } + return Result.ofPlanNode(new ApplyNode( + applyNode.getId(), + applyNode.getInput(), + applyNode.getSubquery(), + rewrittenAssignments.get(), + applyNode.getCorrelation(), + applyNode.getOriginSubqueryError())); + } } private final class FilterExpressionTranslation @@ -139,4 +199,30 @@ private RowExpression toRowExpression(Expression expression, Rule.Context contex return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionManager(), metadata.getTypeManager(), context.getSession(), false); } + + /** + * Return Optional.empty() to denote unchanged assignments + */ + private Optional translateAssignments(Assignments assignments, Rule.Context context) + { + AssignmentsUtils.Builder builder = AssignmentsUtils.builder(); + boolean anyRewritten = false; + for (Map.Entry entry : assignments.entrySet()) { + RowExpression expression = entry.getValue(); + RowExpression rewritten; + if (isExpression(expression)) { + rewritten = toRowExpression(castToExpression(expression), context); + anyRewritten = true; + } + else { + rewritten = expression; + } + builder.put(entry.getKey(), rewritten); + } + if (!anyRewritten) { + return Optional.empty(); + } + + return Optional.of(builder.build()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index bc30dd64064d2..6c31135a25610 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -603,8 +603,8 @@ private Assignments canonicalize(Assignments oldAssignments) { Map computedExpressions = new HashMap<>(); AssignmentsUtils.Builder assignments = AssignmentsUtils.builder(); - for (Map.Entry entry : oldAssignments.getMap().entrySet()) { - Expression expression = canonicalize(entry.getValue()); + for (Map.Entry entry : oldAssignments.getMap().entrySet()) { + Expression expression = canonicalize(castToExpression(entry.getValue())); if (expression instanceof SymbolReference) { // Always map a trivial symbol projection @@ -629,7 +629,7 @@ else if (DeterminismEvaluator.isDeterministic(expression) && !(expression instan } Symbol canonical = canonicalize(entry.getKey()); - assignments.put(canonical, expression); + assignments.put(canonical, castToRowExpression(expression)); } return assignments.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java index 10cae0573a1e9..74562d462bc58 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java @@ -22,6 +22,7 @@ import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.PlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; @@ -39,6 +40,7 @@ import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -281,7 +283,7 @@ public JoinGraph visitProject(ProjectNode node, Context context) { if (isIdentity(node)) { JoinGraph graph = node.getSource().accept(this, context); - return graph.withAssignments(node.getAssignments().getMap()); + return graph.withAssignments(transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression)); } return visitPlan(node, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 634e37e4baee4..de74951a937b0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.tree.Expression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -31,10 +31,10 @@ public class Assignments { - private final Map assignments; + private final Map assignments; @JsonCreator - public Assignments(@JsonProperty("assignments") Map assignments) + public Assignments(@JsonProperty("assignments") Map assignments) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); } @@ -45,12 +45,12 @@ public List getOutputs() } @JsonProperty("assignments") - public Map getMap() + public Map getMap() { return assignments; } - public Collection getExpressions() + public Collection getExpressions() { return assignments.values(); } @@ -60,12 +60,12 @@ public Set getSymbols() return assignments.keySet(); } - public Set> entrySet() + public Set> entrySet() { return assignments.entrySet(); } - public Expression get(Symbol symbol) + public RowExpression get(Symbol symbol) { return assignments.get(symbol); } @@ -80,7 +80,7 @@ public boolean isEmpty() return size() == 0; } - public void forEach(BiConsumer consumer) + public void forEach(BiConsumer consumer) { assignments.forEach(consumer); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 882e78d7b26d6..4a947b02ee923 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -33,6 +33,7 @@ import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.InterpretedFunctionInvoker; import com.facebook.presto.sql.planner.OrderingScheme; @@ -90,7 +91,6 @@ import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.util.GraphvizPrinter; import com.google.common.base.CaseFormat; import com.google.common.base.Functions; @@ -1030,12 +1030,12 @@ private Void processChildren(PlanNode node, Void context) private void printAssignments(NodeRepresentation nodeOutput, Assignments assignments) { - for (Map.Entry entry : assignments.getMap().entrySet()) { - if (entry.getValue() instanceof SymbolReference && ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + for (Map.Entry entry : assignments.getMap().entrySet()) { + if (entry.getValue() instanceof VariableReferenceExpression && ((VariableReferenceExpression) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments continue; } - nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), entry.getValue()); + nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), formatter.formatRowExpression(entry.getValue())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index 9c2eb0bde5182..afd70beb26ff2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -17,6 +17,7 @@ import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; @@ -40,6 +41,8 @@ import java.util.Map; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Collections.emptyList; @@ -112,16 +115,23 @@ public Void visitProject(ProjectNode node, Void context) { visitPlan(node, context); - for (Map.Entry entry : node.getAssignments().entrySet()) { + for (Map.Entry entry : node.getAssignments().entrySet()) { Type expectedType = types.get(entry.getKey()); - if (entry.getValue() instanceof SymbolReference) { - SymbolReference symbolReference = (SymbolReference) entry.getValue(); - verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); - continue; + RowExpression expression = entry.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + SymbolReference symbolReference = (SymbolReference) castToExpression(expression); + verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), types.get(Symbol.from(symbolReference)).getTypeSignature()); + continue; + } + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, castToExpression(expression), emptyList(), warningCollector); + Type actualType = expressionTypes.get(NodeRef.of(castToExpression(expression))); + verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); + } + else { + Type actualType = expression.getType(); + verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); - verifyTypeSignature(entry.getKey(), expectedType.getTypeSignature(), actualType.getTypeSignature()); } return null; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index db21d51ef8e95..56d64d1d3bae5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; @@ -62,7 +63,6 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -258,8 +258,14 @@ public Void visitProject(ProjectNode node, Set boundSymbols) source.accept(this, boundSymbols); // visit child Set inputs = createInputs(source, boundSymbols); - for (Expression expression : node.getAssignments().getExpressions()) { - Set dependencies = SymbolsExtractor.extractUnique(expression); + for (RowExpression expression : node.getAssignments().getExpressions()) { + Set dependencies; + if (isExpression(expression)) { + dependencies = SymbolsExtractor.extractUnique(castToExpression(expression)); + } + else { + dependencies = SymbolsExtractor.extractUnique(expression); + } checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -626,8 +632,14 @@ public Void visitApply(ApplyNode node, Set boundSymbols) .addAll(createInputs(node.getInput(), boundSymbols)) .build(); - for (Expression expression : node.getSubqueryAssignments().getExpressions()) { - Set dependencies = SymbolsExtractor.extractUnique(expression); + for (RowExpression expression : node.getSubqueryAssignments().getExpressions()) { + Set dependencies; + if (isExpression(expression)) { + dependencies = SymbolsExtractor.extractUnique(castToExpression(expression)); + } + else { + dependencies = SymbolsExtractor.extractUnique(expression); + } checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java index 8a6c2fa040a3b..8daa4100830b9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -20,14 +21,16 @@ import java.util.Map; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; + public class ProjectNodeUtils { private ProjectNodeUtils() {} public static boolean isIdentity(ProjectNode projectNode) { - for (Map.Entry entry : projectNode.getAssignments().entrySet()) { - Expression expression = entry.getValue(); + for (Map.Entry entry : projectNode.getAssignments().entrySet()) { + Expression expression = castToExpression(entry.getValue()); Symbol symbol = entry.getKey(); if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { return false; diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index fe522c29d6332..015110e707cbd 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -14,6 +14,8 @@ package com.facebook.presto.util; import com.facebook.presto.Session; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.SubPlan; @@ -57,7 +59,6 @@ import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; @@ -378,9 +379,9 @@ public Void visitFilter(FilterNode node, Void context) public Void visitProject(ProjectNode node, Void context) { StringBuilder builder = new StringBuilder(); - for (Map.Entry entry : node.getAssignments().entrySet()) { - if ((entry.getValue() instanceof SymbolReference) && - ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + for (Map.Entry entry : node.getAssignments().entrySet()) { + if ((entry.getValue() instanceof VariableReferenceExpression) && + ((VariableReferenceExpression) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments continue; } From ae2449e23c9d3d104cd1625614e978c9f1802e54 Mon Sep 17 00:00:00 2001 From: James Sun Date: Tue, 23 Apr 2019 00:23:32 -0700 Subject: [PATCH 4/6] workaround in order not to pass compilation for tests --- .../presto/sql/planner/AssignmentsUtils.java | 15 ++++++++ .../planner/assertions/ExpressionMatcher.java | 34 +++++++++++++------ .../sql/planner/assertions/SymbolAliases.java | 16 +++++++-- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java index eda555860963f..2ea1389ca737c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/AssignmentsUtils.java @@ -73,11 +73,21 @@ public static Assignments of(Symbol symbol, RowExpression expression) return builder().put(symbol, expression).build(); } + public static Assignments of(Symbol symbol, Expression expression) + { + return builder().put(symbol, castToRowExpression(expression)).build(); + } + public static Assignments of(Symbol symbol1, RowExpression expression1, Symbol symbol2, RowExpression expression2) { return builder().put(symbol1, expression1).put(symbol2, expression2).build(); } + public static Assignments of(Symbol symbol1, Expression expression1, Symbol symbol2, Expression expression2) + { + return builder().put(symbol1, castToRowExpression(expression1)).put(symbol2, castToRowExpression(expression2)).build(); + } + // Originally, the following functions are not static move assignments as member variables public static Assignments rewrite(Assignments assignments, ExpressionRewriter rewriter) { @@ -155,6 +165,11 @@ public Builder put(Symbol symbol, RowExpression expression) return this; } + public Builder put(Symbol symbol, Expression expression) + { + return put(symbol, castToRowExpression(expression)); + } + public Builder put(Map.Entry assignment) { put(assignment.getKey(), assignment.getValue()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java index 83b363b1a303d..f41419caf9a38 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -29,6 +30,8 @@ import java.util.stream.Collectors; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; @@ -54,29 +57,38 @@ private Expression expression(String sql) public Optional getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { Optional result = Optional.empty(); - ImmutableList.Builder matchesBuilder = ImmutableList.builder(); - Map assignments = getAssignments(node); + ImmutableList.Builder matchesBuilder = ImmutableList.builder(); + Map assignments = getAssignments(node); if (assignments == null) { return result; } - ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); - - for (Map.Entry assignment : assignments.entrySet()) { - if (verifier.process(assignment.getValue(), expression)) { - result = Optional.of(assignment.getKey()); - matchesBuilder.add(assignment.getValue()); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression rightValue = assignment.getValue(); + if (isExpression(rightValue)) { + ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); + if (verifier.process(castToExpression(rightValue), expression)) { + result = Optional.of(assignment.getKey()); + matchesBuilder.add(castToExpression(rightValue)); + } + } + else { + RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session); + if (verifier.process(expression, rightValue)) { + result = Optional.of(assignment.getKey()); + matchesBuilder.add(rightValue); + } } } - List matches = matchesBuilder.build(); + List matches = matchesBuilder.build(); checkState(matches.size() < 2, "Ambiguous expression %s matches multiple assignments", expression, - (matches.stream().map(Expression::toString).collect(Collectors.joining(", ")))); + (matches.stream().map(Object::toString).collect(Collectors.joining(", ")))); return result; } - private static Map getAssignments(PlanNode node) + private static Map getAssignments(PlanNode node) { if (node instanceof ProjectNode) { ProjectNode projectNode = (ProjectNode) node; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java index 1ec113d359b06..7625794c72ac4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java @@ -13,9 +13,10 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; @@ -24,6 +25,8 @@ import java.util.Map; import java.util.Optional; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; @@ -97,9 +100,16 @@ private static String toKey(String alias) private Map getUpdatedAssignments(Assignments assignments) { ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { + for (Map.Entry assignment : assignments.getMap().entrySet()) { for (Map.Entry existingAlias : map.entrySet()) { - if (assignment.getValue().equals(existingAlias.getValue())) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression) && castToExpression(expression).equals(existingAlias.getValue())) { + // Simple symbol rename + mapUpdate.put(existingAlias.getKey(), assignment.getKey().toSymbolReference()); + } + else if (!isExpression(expression) && + (expression instanceof VariableReferenceExpression) && + new SymbolReference(((VariableReferenceExpression) expression).getName()).equals(existingAlias.getValue())) { // Simple symbol rename mapUpdate.put(existingAlias.getKey(), assignment.getKey().toSymbolReference()); } From 150be12856fdbed5830fa77a661b0a7b6904aa71 Mon Sep 17 00:00:00 2001 From: James Sun Date: Wed, 24 Apr 2019 21:51:09 -0700 Subject: [PATCH 5/6] Make ProjectNodeUtil to handle RowExpression --- .../presto/sql/relational/ProjectNodeUtils.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java index 8daa4100830b9..38cb1062a1dad 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.relational; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -22,6 +23,7 @@ import java.util.Map; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; public class ProjectNodeUtils { @@ -30,10 +32,18 @@ private ProjectNodeUtils() {} public static boolean isIdentity(ProjectNode projectNode) { for (Map.Entry entry : projectNode.getAssignments().entrySet()) { - Expression expression = castToExpression(entry.getValue()); + RowExpression value = entry.getValue(); Symbol symbol = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { - return false; + if (isExpression(value)) { + Expression expression = castToExpression(value); + if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { + return false; + } + } + else { + if (!(value instanceof VariableReferenceExpression && ((VariableReferenceExpression) value).getName().equals(symbol.getName()))) { + return false; + } } } return true; From 0dd0d9dd3f920c07ee54c70670438217f9f75c6b Mon Sep 17 00:00:00 2001 From: James Sun Date: Thu, 25 Apr 2019 17:34:28 -0700 Subject: [PATCH 6/6] fix graphviz printer --- .../src/main/java/com/facebook/presto/util/GraphvizPrinter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 015110e707cbd..5225b5eaa4dd8 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -385,7 +385,7 @@ public Void visitProject(ProjectNode node, Void context) // skip identity assignments continue; } - builder.append(format("%s := %s\\n", entry.getKey(), entry.getValue())); + builder.append(format("%s := %s\\n", entry.getKey(), formatter.formatRowExpression(entry.getValue()))); } printNode(node, "Project", builder.toString(), NODE_COLORS.get(NodeType.PROJECT));