From 9719fe555d57d069ccd497b74b394327dccb45a4 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:38 -0700 Subject: [PATCH 01/11] Allow type coercion in DependencyChecker checkDependencies checks if referenced variables in assignment are same as input. If the assignment is type only cast (for example cast varchar(3) to varchar(5)), the referenced variable will be implicitly changed to target type. Reading symbol1, VARCHAR(3) input into symbol1 VARCHAR(5) still should be OK. Simple containsAll works when input is considered as symbol will fail when we add in the type checks. --- .../sanity/ValidateDependenciesChecker.java | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 9750861edd3b5..eb2deda36ecfe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.TypeProvider; @@ -92,22 +93,24 @@ public final class ValidateDependenciesChecker @Override public void validate(PlanNode plan, Session session, Metadata metadata, SqlParser sqlParser, TypeProvider types, WarningCollector warningCollector) { - validate(plan, types); + validate(plan, types, metadata.getTypeManager()); } - public static void validate(PlanNode plan, TypeProvider types) + public static void validate(PlanNode plan, TypeProvider types, TypeManager typeManager) { - plan.accept(new Visitor(types), ImmutableSet.of()); + plan.accept(new Visitor(types, typeManager), ImmutableSet.of()); } private static class Visitor extends InternalPlanVisitor> { private final TypeProvider types; + private final TypeManager typeManager; - public Visitor(TypeProvider types) + public Visitor(TypeProvider types, TypeManager typeManager) { this.types = requireNonNull(types, "types is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); } @Override @@ -709,10 +712,17 @@ private static ImmutableSet createInputs(PlanNode s .addAll(boundVariables) .build(); } - } - private static void checkDependencies(Collection inputs, Collection required, String message, Object... parameters) - { - checkArgument(ImmutableSet.copyOf(inputs).containsAll(required), message, parameters); + private void checkDependencies(Collection inputs, Collection required, String message, Object... parameters) + { + for (VariableReferenceExpression target : required) { + checkArgument( + inputs.stream() + .anyMatch(input -> input.getName().equalsIgnoreCase(target.getName()) && + typeManager.isTypeOnlyCoercion(input.getType(), target.getType())), + message, + parameters); + } + } } } From 0abe62436c5c207be8875d162fd2345255ecef7e Mon Sep 17 00:00:00 2001 From: James Sun Date: Mon, 17 Jun 2019 17:24:39 -0700 Subject: [PATCH 02/11] Extract ApplyNode::isSupportedSubqueryExpression to utility --- .../presto/sql/planner/SubqueryPlanner.java | 2 + .../rule/ExpressionRewriteRuleSet.java | 2 + .../planner/optimizations/ApplyNodeUtil.java | 42 +++++++++++++++++++ .../PruneUnreferencedOutputs.java | 5 ++- .../UnaliasSymbolReferences.java | 5 ++- .../presto/sql/planner/plan/ApplyNode.java | 14 +------ .../iterative/rule/test/PlanBuilder.java | 2 + 7 files changed, 58 insertions(+), 14 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index ca124d2dbc7f6..815a0116d99db 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -61,6 +61,7 @@ import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.analyzer.SemanticExceptions.subQueryNotSupportedError; import static com.facebook.presto.sql.planner.ExpressionNodeInliner.replaceExpression; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -441,6 +442,7 @@ private PlanBuilder appendApplyNode(PlanBuilder subPlan, Node subquery, PlanNode TranslationMap translations = subPlan.copyTranslations(); PlanNode root = subPlan.getRoot(); + verifySubquerySupported(subqueryAssignments); return new PlanBuilder(translations, new ApplyNode(idAllocator.getNextId(), root, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index ee9b8152bcb70..61d2c20633e3a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -38,6 +38,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.planner.plan.Patterns.filter; @@ -328,6 +329,7 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) if (applyNode.getSubqueryAssignments().equals(subqueryAssignments)) { return Result.empty(); } + verifySubquerySupported(subqueryAssignments); return Result.ofPlanNode(new ApplyNode( applyNode.getId(), applyNode.getInput(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java new file mode 100644 index 0000000000000..348871a467fa3 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.tree.ExistsPredicate; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.InPredicate; +import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; + +import static com.google.common.base.Preconditions.checkArgument; + +public final class ApplyNodeUtil +{ + private ApplyNodeUtil() {} + + public static void verifySubquerySupported(Assignments assignments) + { + checkArgument( + assignments.getExpressions().stream().allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), + "Unexpected expression used for subquery expression"); + } + + public static boolean isSupportedSubqueryExpression(Expression expression) + { + // TODO: add RowExpression support + return expression instanceof InPredicate || + expression instanceof ExistsPredicate || + expression instanceof QuantifiedComparisonExpression; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index bfd7c20df0ff9..9ac34447c8a27 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -82,6 +82,7 @@ import java.util.stream.Collectors; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; @@ -818,7 +819,9 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) PlanNode subquery = context.rewrite(node.getSubquery()); List canonicalCorrelation = Lists.transform(node.getCorrelation(), this::canonicalize); - return new ApplyNode(node.getId(), source, subquery, canonicalize(node.getSubqueryAssignments()), canonicalCorrelation, node.getOriginSubqueryError()); + Assignments assignments = canonicalize(node.getSubqueryAssignments()); + verifySubquerySupported(assignments); + return new ApplyNode(node.getId(), source, subquery, assignments, canonicalCorrelation, node.getOriginSubqueryError()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index e05c6ca163103..06d9d4ec4217a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -16,10 +16,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.tree.ExistsPredicate; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.InPredicate; -import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; +import com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -87,7 +84,7 @@ public ApplyNode( checkArgument(input.getOutputVariables().containsAll(correlation), "Input does not contain symbols from correlation"); checkArgument( - subqueryAssignments.getExpressions().stream().allMatch(ApplyNode::isSupportedSubqueryExpression), + subqueryAssignments.getExpressions().stream().allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), "Unexpected expression used for subquery expression"); this.input = input; @@ -97,13 +94,6 @@ public ApplyNode( this.originSubqueryError = originSubqueryError; } - private static boolean isSupportedSubqueryExpression(Expression expression) - { - return expression instanceof InPredicate || - expression instanceof ExistsPredicate || - expression instanceof QuantifiedComparisonExpression; - } - @JsonProperty public PlanNode getInput() { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 8f4a6809efc09..3a0b3afaf60d2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -98,6 +98,7 @@ import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.constantNull; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -378,6 +379,7 @@ protected AggregationNode build() public ApplyNode apply(Assignments subqueryAssignments, List correlation, PlanNode input, PlanNode subquery) { + verifySubquerySupported(subqueryAssignments); return new ApplyNode(idAllocator.getNextId(), input, subquery, subqueryAssignments, correlation, ""); } From f3eb1706a67de2980eb3e05d7a8b4607fdd1d0f4 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:40 -0700 Subject: [PATCH 03/11] Translate row expresion before fragment it in TestCostCalculator --- .../presto/cost/TestCostCalculator.java | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 1c521547b0552..5b0cf18bcce7d 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -39,9 +39,13 @@ import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.PlanFragmenter; +import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.SubPlan; import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; +import com.facebook.presto.sql.planner.optimizations.TranslateExpressions; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; @@ -81,6 +85,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; import static com.facebook.presto.sql.planner.plan.ExchangeNode.replicatedExchange; import static com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange; +import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; @@ -417,13 +422,14 @@ public void testAggregation() @Test public void testRepartitionedJoinWithExchange() { - TableScanNode ts1 = tableScan("ts1", ImmutableList.of(new VariableReferenceExpression("orderkey", BIGINT))); - TableScanNode ts2 = tableScan("ts2", ImmutableList.of(new VariableReferenceExpression("orderkey_0", BIGINT))); + TableScanNode ts1 = tableScan("ts1", "orderkey"); + TableScanNode ts2 = tableScan("ts2", "orderkey_0"); + PlanNode p1 = project("p1", ts1, variable("orderkey_1", BIGINT), new SymbolReference("orderkey")); ExchangeNode remoteExchange1 = systemPartitionedExchange( new PlanNodeId("re1"), REMOTE_STREAMING, - ts1, - ImmutableList.of(new VariableReferenceExpression("orderkey", BIGINT)), + p1, + ImmutableList.of(new VariableReferenceExpression("orderkey_1", BIGINT)), Optional.empty()); ExchangeNode remoteExchange2 = systemPartitionedExchange( new PlanNodeId("re2"), @@ -442,7 +448,7 @@ public void testRepartitionedJoinWithExchange() remoteExchange1, localExchange, JoinNode.DistributionType.PARTITIONED, - "orderkey", + "orderkey_1", "orderkey_0"); Map stats = ImmutableMap.builder() @@ -450,11 +456,13 @@ public void testRepartitionedJoinWithExchange() .put("re1", statsEstimate(remoteExchange1, 10000)) .put("re2", statsEstimate(remoteExchange2, 10000)) .put("le", statsEstimate(localExchange, 6000)) + .put("p1", statsEstimate(p1, 6000)) .put("ts1", statsEstimate(ts1, 6000)) .put("ts2", statsEstimate(ts2, 1000)) .build(); Map types = ImmutableMap.of( "orderkey", BIGINT, + "orderkey_1", BIGINT, "orderkey_0", BIGINT); assertFragmentedEqualsUnfragmented(join, stats, types); @@ -548,14 +556,22 @@ private CostAssertionBuilder assertCostFragmentedPlan( Map stats, Map types) { - TypeProvider typeProvider = TypeProvider.copyOf(types.entrySet().stream() - .collect(ImmutableMap.toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))); + Map symbolTypes = types.entrySet().stream() + .collect(ImmutableMap.toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue)); + TypeProvider typeProvider = TypeProvider.copyOf(symbolTypes); StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider); CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session, typeProvider); - SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); + PlanNode plan = translateExpression(node, statsCalculator(stats), typeProvider); + SubPlan subPlan = fragment(new Plan(plan, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); return new CostAssertionBuilder(subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown())); } + private PlanNode translateExpression(PlanNode node, StatsCalculator statsCalculator, TypeProvider typeProvider) + { + IterativeOptimizer optimizer = new IterativeOptimizer(new RuleStatsRecorder(), statsCalculator, costCalculatorUsingExchanges, new TranslateExpressions(metadata, new SqlParser()).rules()); + return optimizer.optimize(node, session, typeProvider, new SymbolAllocator(typeProvider.allTypes()), new PlanNodeIdAllocator(), WarningCollector.NOOP); + } + private static class TestingCostProvider implements CostProvider { @@ -684,6 +700,7 @@ private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalcula .collect(ImmutableMap.toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))); StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider); CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session, typeProvider); + node = translateExpression(node, statsCalculator, typeProvider); SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); return subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()); } From d419732738b608dd68143a8c228661bfaeeb24c2 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:42 -0700 Subject: [PATCH 04/11] Catch exception in test unsupported subquery We don't support correlated subquery with multiple correlated columns. Instead expect it to return a invalid plan, expect an exception. --- .../com/facebook/presto/sql/planner/TestLogicalPlanner.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index d79308ee566cb..e9a1f884bab0d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -730,7 +730,7 @@ public void testSymbolsPrunedInCorrelatedInPredicateSource() anyTree(tableScan("orders"))))); } - @Test + @Test(expectedExceptions = RuntimeException.class, expectedExceptionsMessageRegExp = ".*Given correlated subquery is not supported.*") public void testDoubleNestedCorrelatedSubqueries() { assertPlan( @@ -749,7 +749,7 @@ public void testDoubleNestedCorrelatedSubqueries() any( any( tableScan("lineitem", ImmutableMap.of("L", "orderkey")))))))), - MorePredicates.isInstanceOfAny(AddLocalExchanges.class, CheckSubqueryNodesAreRewritten.class).negate()); + MorePredicates.isInstanceOfAny(AddLocalExchanges.class).negate()); } @Test From 60d1f349da3a57eca54daac766521ad5d4bd76d9 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:43 -0700 Subject: [PATCH 05/11] Fix RowExpressionVerifier cast assertion --- .../presto/sql/planner/assertions/RowExpressionVerifier.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index c0d733ba538a2..db0cda9d496c1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -137,7 +137,7 @@ protected Boolean visitCast(Cast expected, RowExpression actual) return false; } - if (!expected.getType().equals(actual.getType().toString())) { + if (!expected.getType().equalsIgnoreCase(actual.getType().toString())) { return false; } From f0d87e59c56b0564f649eaeff05688c9a645059c Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:44 -0700 Subject: [PATCH 06/11] Move toVariableReference method into PlannerUtils --- .../java/com/facebook/presto/sql/planner/PlannerUtils.java | 5 +++++ .../iterative/rule/PushAggregationThroughOuterJoin.java | 2 +- .../iterative/rule/PushProjectionThroughExchange.java | 2 +- .../presto/sql/planner/optimizations/AddExchanges.java | 7 +------ .../optimizations/OptimizeMixedDistinctAggregations.java | 2 +- .../sql/planner/optimizations/PropertyDerivations.java | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 7dff479a9e588..5bcfbcbdec64d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -69,4 +69,9 @@ public static OrderingScheme toOrderingScheme(List Streams.forEachPair(orderingSymbols.stream(), sortOrders.stream(), orderings::putIfAbsent); return new OrderingScheme(ImmutableList.copyOf(orderings.keySet()), orderings); } + + public static VariableReferenceExpression toVariableReference(Symbol symbol, TypeProvider types) + { + return variable(symbol.getName(), types.get(symbol)); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 6ce88d9248791..6253ea71c6763 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -51,7 +51,7 @@ import static com.facebook.presto.SystemSessionProperties.shouldPushAggregationThroughJoin; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 4dc35b308087c..d7513061fb2d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -39,7 +39,7 @@ import static com.facebook.presto.matching.Capture.newCapture; import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; -import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 9a9cebc0737c7..bc72830d47bad 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -106,6 +106,7 @@ import static com.facebook.presto.SystemSessionProperties.preferStreamingOperators; import static com.facebook.presto.sql.planner.FragmentTableScanCounter.getNumberOfTableScans; import static com.facebook.presto.sql.planner.FragmentTableScanCounter.hasMultipleTableScans; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; @@ -123,7 +124,6 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange; import static com.facebook.presto.sql.planner.plan.ExchangeNode.replicatedExchange; import static com.facebook.presto.sql.planner.plan.ExchangeNode.roundRobinExchange; -import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Preconditions.checkArgument; @@ -1383,11 +1383,6 @@ public static Map comp return outputToInput; } - public static VariableReferenceExpression toVariableReference(Symbol symbol, TypeProvider types) - { - return variable(symbol.getName(), types.get(symbol)); - } - @VisibleForTesting static Comparator streamingExecutionPreference(PreferredProperties preferred) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 8a3875f57adf8..6afd3cfef2c56 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -57,7 +57,7 @@ import static com.facebook.presto.SystemSessionProperties.isOptimizeDistinctAggregationEnabled; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; -import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.relational.Expressions.variable; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 393bf77396784..f92c589a67593 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -93,13 +93,13 @@ import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValuesToConstantExpressions; import static com.facebook.presto.spi.relation.DomainTranslator.BASIC_COLUMN_EXTRACTOR; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.coordinatorSingleStreamPartition; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.partitionedOn; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.singleStreamPartition; import static com.facebook.presto.sql.planner.optimizations.ActualProperties.Global.streamPartitionedOn; -import static com.facebook.presto.sql.planner.optimizations.AddExchanges.toVariableReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.Preconditions.checkArgument; From c19cfade4301c1259eafab700765214b8611a1ee Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:45 -0700 Subject: [PATCH 07/11] Remove unused getOutputSymbols --- .../com/facebook/presto/sql/planner/plan/Assignments.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 8aacfd8275637..6c28e4c77e958 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -89,11 +89,6 @@ public Assignments(@JsonProperty("assignments") Map getOutputSymbols() - { - return assignments.keySet().stream().map(VariableReferenceExpression::getName).map(Symbol::new).collect(toImmutableList()); - } - public List getOutputs() { return ImmutableList.copyOf(assignments.keySet()); From a034b1f6e7f273e32b8c36e88db3feaf78ec9260 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:47 -0700 Subject: [PATCH 08/11] Move identity functions into AssignmentUtils --- .../sql/planner/LocalExecutionPlanner.java | 3 +- .../presto/sql/planner/PlanFragmenter.java | 3 +- .../presto/sql/planner/QueryPlanner.java | 8 ++- .../presto/sql/planner/RelationPlanner.java | 5 +- .../iterative/rule/ExtractSpatialJoins.java | 5 +- .../iterative/rule/GatherAndMergeWindows.java | 6 +- .../rule/ImplementFilteredAggregations.java | 3 +- .../iterative/rule/InlineProjections.java | 6 +- .../rule/PushProjectionThroughExchange.java | 2 +- ...RewriteSpatialPartitioningAggregation.java | 3 +- .../TransformCorrelatedInPredicateToJoin.java | 9 ++- .../TransformCorrelatedScalarSubquery.java | 4 +- ...mCorrelatedSingleRowSubqueryToProject.java | 3 +- .../TransformExistsApplyToLateralNode.java | 3 +- .../sql/planner/iterative/rule/Util.java | 4 +- .../ImplementIntersectAndExceptAsUnion.java | 3 +- .../optimizations/IndexJoinOptimizer.java | 6 +- .../OptimizeMixedDistinctAggregations.java | 3 +- .../optimizations/PlanNodeDecorrelator.java | 3 +- .../optimizations/PredicatePushDown.java | 3 +- .../ScalarAggregationToJoinRewriter.java | 8 ++- ...uantifiedComparisonApplyToLateralJoin.java | 3 +- .../sql/planner/plan/AssignmentUtils.java | 67 +++++++++++++++++++ .../presto/sql/planner/plan/Assignments.java | 33 --------- ...tSimpleFilterProjectSemiJoinStatsRule.java | 38 +++++------ .../sql/planner/TestLogicalPlanner.java | 1 - .../iterative/TestIterativeOptimizer.java | 4 +- .../rule/TestAddIntermediateAggregations.java | 4 +- .../iterative/rule/TestInlineProjections.java | 7 +- .../rule/TestMergeAdjacentWindows.java | 6 +- .../rule/TestPruneAggregationColumns.java | 4 +- .../rule/TestPruneCrossJoinColumns.java | 4 +- .../rule/TestPruneFilterColumns.java | 4 +- .../rule/TestPruneIndexSourceColumns.java | 4 +- .../iterative/rule/TestPruneJoinColumns.java | 3 +- .../iterative/rule/TestPruneLimitColumns.java | 4 +- .../rule/TestPruneMarkDistinctColumns.java | 7 +- .../rule/TestPruneProjectColumns.java | 10 +-- .../rule/TestPruneSemiJoinColumns.java | 4 +- .../iterative/rule/TestPruneTopNColumns.java | 4 +- .../rule/TestPruneWindowColumns.java | 4 +- .../TestPushAggregationThroughOuterJoin.java | 3 +- 42 files changed, 187 insertions(+), 124 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index e447827bae8e9..ab03e4d4aca71 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -250,6 +250,7 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; @@ -1140,7 +1141,7 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext RowExpression filterExpression = node.getPredicate(); List outputVariables = node.getOutputVariables(); - return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), Assignments.identity(outputVariables), outputVariables); + return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identityAssignmentsAsSymbolReferences(outputVariables), outputVariables); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index 46785d972101d..edc36ab73328f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -102,6 +102,7 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_MATERIALIZED; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; @@ -656,7 +657,7 @@ private TableFinishNode createTemporaryTableWrite( sources = sources.stream() .map(source -> { Assignments.Builder assignments = Assignments.builder(); - assignments.putIdentities(source.getOutputVariables()); + assignments.putAll(identitiesAsSymbolReferences(source.getOutputVariables())); constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable))); return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); }) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 52458394d6586..d2ff000c99edd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -33,6 +33,7 @@ import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; +import com.facebook.presto.sql.planner.plan.AssignmentUtils; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.DeleteNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; @@ -85,6 +86,7 @@ import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toWindowType; import static com.facebook.presto.sql.planner.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.MoreObjects.firstNonNull; @@ -411,7 +413,7 @@ private PlanBuilder explicitCoercionSymbols(PlanBuilder subPlan, List assignments.put(key, new SymbolReference(value.getName()))); ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build()); @@ -673,7 +675,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica TranslationMap newTranslations = subPlan.copyTranslations(); Assignments.Builder projections = Assignments.builder(); - projections.putIdentities(subPlan.getRoot().getOutputVariables()); + projections.putAll(identitiesAsSymbolReferences(subPlan.getRoot().getOutputVariables())); List> descriptor = groupingSets.stream() .map(set -> set.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 6926c4c5b5e66..6296f4d6cac38 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -93,6 +93,7 @@ import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.google.common.base.Preconditions.checkArgument; @@ -434,8 +435,8 @@ If casts are redundant (due to column type and common type being equal), Assignments.Builder leftCoercions = Assignments.builder(); Assignments.Builder rightCoercions = Assignments.builder(); - leftCoercions.putIdentities(left.getRoot().getOutputVariables()); - rightCoercions.putIdentities(right.getRoot().getOutputVariables()); + leftCoercions.putAll(identitiesAsSymbolReferences(left.getRoot().getOutputVariables())); + rightCoercions.putAll(identitiesAsSymbolReferences(right.getRoot().getOutputVariables())); for (int i = 0; i < joinColumns.size(); i++) { Identifier identifier = joinColumns.get(i); Type type = analysis.getType(identifier); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index 21d2e2ef66b39..f205df0b59be6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -88,6 +88,7 @@ import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.ExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.planner.SymbolsExtractor.extractUniqueVariable; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.Patterns.filter; @@ -587,7 +588,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe { Assignments.Builder projections = Assignments.builder(); for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { - projections.putIdentity(outputVariable); + projections.put(identityAsSymbolReference(outputVariable)); } projections.put(variable, expression); @@ -598,7 +599,7 @@ private static PlanNode addPartitioningNodes(Context context, PlanNode node, Var { Assignments.Builder projections = Assignments.builder(); for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { - projections.putIdentity(outputVariable); + projections.put(identityAsSymbolReference(outputVariable)); } ImmutableList.Builder partitioningArguments = ImmutableList.builder() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index 10772d81a311d..170f6df503d01 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -43,6 +43,8 @@ import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.facebook.presto.sql.planner.iterative.rule.Util.transpose; import static com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.dependsOn; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.isIdentity; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.window; @@ -141,7 +143,7 @@ protected static Optional pullWindowNodeAboveProjects( // The target node, when hoisted above the projections, will provide the symbols directly. Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( project.getAssignments().getMap(), - output -> !(project.getAssignments().isIdentity(output) && targetOutputs.contains(output))); + output -> !(isIdentity(project.getAssignments(), output) && targetOutputs.contains(output))); if (targetInputs.stream().anyMatch(assignmentsWithoutTargetOutputIdentities::containsKey)) { // Redefinition of an input to the target -- can't handle this case. @@ -150,7 +152,7 @@ protected static Optional pullWindowNodeAboveProjects( Assignments newAssignments = Assignments.builder() .putAll(assignmentsWithoutTargetOutputIdentities) - .putIdentities(targetInputs) + .putAll(identitiesAsSymbolReferences(targetInputs)) .build(); if (!newTargetChildOutputs.containsAll(extractUniqueVariable(newAssignments.getExpressions(), context.getSymbolAllocator().getTypes()))) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 4547ec9f00760..136505e7a74a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -33,6 +33,7 @@ import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -121,7 +122,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont } // identity projection for all existing inputs - newAssignments.putIdentities(aggregation.getSource().getOutputVariables()); + newAssignments.putAll(identitiesAsSymbolReferences(aggregation.getSource().getOutputVariables())); return Result.ofPlanNode( new AggregationNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index c2692540936cc..f96c90b6810f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -37,6 +37,8 @@ import static com.facebook.presto.matching.Capture.newCapture; import static com.facebook.presto.sql.planner.ExpressionSymbolInliner.inlineSymbols; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.isIdentity; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static java.util.stream.Collectors.toSet; @@ -97,7 +99,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) } } for (VariableReferenceExpression input : inputs) { - childAssignments.putIdentity(input); + childAssignments.put(identityAsSymbolReference(input)); } return Result.ofPlanNode( @@ -157,7 +159,7 @@ private Sets.SetView extractInliningTargets(Project Set singletons = dependencies.entrySet().stream() .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics - .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + .filter(entry -> !isIdentity(child.getAssignments(), entry.getKey())) // skip identities, otherwise, this rule will keep firing forever .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index d7513061fb2d8..438e3129a86ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -38,8 +38,8 @@ import static com.facebook.presto.matching.Capture.newCapture; import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; -import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index aaa34c1d22c11..b4414bef5c0f2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -40,6 +40,7 @@ import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; @@ -134,7 +135,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) context.getIdAllocator().getNextId(), node.getSource(), Assignments.builder() - .putIdentities(node.getSource().getOutputVariables()) + .putAll(identitiesAsSymbolReferences(node.getSource().getOutputVariables())) .put(partitionCountVariable, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) .putAll(envelopeAssignments.build()) .build()), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 1b60eceb43006..3a79808864f11 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -31,6 +31,7 @@ import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.AssignmentUtils; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -66,6 +67,7 @@ import static com.facebook.presto.sql.ExpressionUtils.and; import static com.facebook.presto.sql.ExpressionUtils.or; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.Apply.correlation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; @@ -180,7 +182,7 @@ private PlanNode buildInPredicateEquivalent( idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder() - .putIdentities(decorrelatedBuildSource.getOutputVariables()) + .putAll(identitiesAsSymbolReferences(decorrelatedBuildSource.getOutputVariables())) .put(buildSideKnownNonNull, bigint(0)) .build()); @@ -230,7 +232,7 @@ private PlanNode buildInPredicateEquivalent( idAllocator.getNextId(), aggregation, Assignments.builder() - .putIdentities(apply.getInput().getOutputVariables()) + .putAll(identitiesAsSymbolReferences(apply.getInput().getOutputVariables())) .put(inPredicateOutputVariable, inPredicateEquivalent) .build()); } @@ -327,7 +329,8 @@ public Optional visitProject(ProjectNode node, PlanNode reference) .map(SymbolReference.class::cast) .map(symbolReference -> new VariableReferenceExpression(symbolReference.getName(), types.get(Symbol.from(symbolReference)))) .filter(variable -> !correlation.contains(variable)) - .forEach(assignments::putIdentity); + .map(AssignmentUtils::identityAsSymbolReference) + .forEach(assignments::put); return new Decorrelated( decorrelated.getCorrelatedPredicates(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java index 75ddbf6358665..9230c26ea87a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarSubquery.java @@ -21,7 +21,6 @@ import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AssignUniqueId; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MarkDistinctNode; @@ -44,6 +43,7 @@ import static com.facebook.presto.spi.type.StandardTypes.BOOLEAN; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.LateralJoin.correlation; import static com.facebook.presto.sql.planner.plan.Patterns.lateralJoin; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -155,6 +155,6 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context return Result.ofPlanNode(new ProjectNode( context.getIdAllocator().getNextId(), filterNode, - Assignments.identity(lateralJoinNode.getOutputVariables()))); + identityAssignmentsAsSymbolReferences(lateralJoinNode.getOutputVariables()))); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java index 82485dfd6985c..56a4eff670c69 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java @@ -25,6 +25,7 @@ import java.util.List; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.lateralJoin; /** @@ -76,7 +77,7 @@ public Result apply(LateralJoinNode parent, Captures captures, Context context) } else if (subqueryProjections.size() == 1) { Assignments assignments = Assignments.builder() - .putIdentities(parent.getInput().getOutputVariables()) + .putAll(identitiesAsSymbolReferences(parent.getInput().getOutputVariables())) .putAll(subqueryProjections.get(0).getAssignments()) .build(); return Result.ofPlanNode(projectNode(parent.getInput(), assignments, context)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 936198ea241b3..0c5ed7bb4c7ee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -46,6 +46,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; @@ -121,7 +122,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C VariableReferenceExpression subqueryTrue = context.getSymbolAllocator().newVariable("subqueryTrue", BOOLEAN); Assignments.Builder assignments = Assignments.builder(); - assignments.putIdentities(applyNode.getInput().getOutputVariables()); + assignments.putAll(identitiesAsSymbolReferences(applyNode.getInput().getOutputVariables())); assignments.put(exists, new CoalesceExpression(ImmutableList.of(new SymbolReference(subqueryTrue.getName()), BooleanLiteral.FALSE_LITERAL))); PlanNode subquery = new ProjectNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java index 7c913cf71d658..5e30f606036c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/Util.java @@ -18,7 +18,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.TypeProvider; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; @@ -30,6 +29,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -83,7 +83,7 @@ public static Optional restrictOutputs(PlanNodeIdAllocator idAllocator new ProjectNode( idAllocator.getNextId(), node, - Assignments.identity(restrictedOutputs))); + identityAssignmentsAsSymbolReferences(restrictedOutputs))); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 040dac66d785f..c2447078a22e5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -56,6 +56,7 @@ import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -295,7 +296,7 @@ private ProjectNode project(PlanNode node, List col return new ProjectNode( idAllocator.getNextId(), node, - Assignments.identity(columns)); + identityAssignmentsAsSymbolReferences(columns)); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index 84beaec29ddf7..15de6d73ff02e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -30,7 +30,6 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.IndexSourceNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; @@ -57,6 +56,7 @@ import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @@ -162,7 +162,7 @@ else if (leftIndexCandidate.isPresent()) { indexJoinNode = new ProjectNode( idAllocator.getNextId(), indexJoinNode, - Assignments.identity(node.getOutputVariables())); + identityAssignmentsAsSymbolReferences(node.getOutputVariables())); } return indexJoinNode; @@ -211,7 +211,7 @@ private static PlanNode createIndexJoinWithExpectedOutputs( result = new ProjectNode( idAllocator.getNextId(), result, - Assignments.identity(expectedOutputs)); + identityAssignmentsAsSymbolReferences(expectedOutputs)); } return result; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 6afd3cfef2c56..82df34d78c6e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -60,6 +60,7 @@ import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference; import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; @@ -225,7 +226,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext visitProject(ProjectNode node, Void context Assignments assignments = Assignments.builder() .putAll(node.getAssignments()) - .putIdentities(variablesToAdd) + .putAll(identitiesAsSymbolReferences(variablesToAdd)) .build(); return Optional.of(new DecorrelationResult( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index d9143850e5697..30b9a1e582aa5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -84,6 +84,7 @@ import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic; import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; @@ -551,7 +552,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) } if (!node.getOutputVariables().equals(output.getOutputVariables())) { - output = new ProjectNode(idAllocator.getNextId(), output, Assignments.identity(node.getOutputVariables())); + output = new ProjectNode(idAllocator.getNextId(), output, identityAssignmentsAsSymbolReferences(node.getOutputVariables())); } return output; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index 8241dbbb2486a..3faffde8063ea 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -46,6 +46,8 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -81,7 +83,7 @@ public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, Aggreg VariableReferenceExpression nonNull = symbolAllocator.newVariable("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = Assignments.builder() - .putIdentities(source.get().getNode().getOutputVariables()) + .putAll(identitiesAsSymbolReferences(source.get().getNode().getOutputVariables())) .put(nonNull, TRUE_LITERAL) .build(); ProjectNode scalarAggregationSourceWithNonNullableVariable = new ProjectNode( @@ -142,7 +144,7 @@ private PlanNode rewriteScalarAggregation( if (subqueryProjection.isPresent()) { Assignments assignments = Assignments.builder() - .putIdentities(aggregationOutputVariables) + .putAll(identitiesAsSymbolReferences(aggregationOutputVariables)) .putAll(subqueryProjection.get().getAssignments()) .build(); @@ -155,7 +157,7 @@ private PlanNode rewriteScalarAggregation( return new ProjectNode( idAllocator.getNextId(), aggregationNode.get(), - Assignments.identity(aggregationOutputVariables)); + identityAssignmentsAsSymbolReferences(aggregationOutputVariables)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index 9213773d0cf7e..f4ef3278f8b83 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -59,6 +59,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; @@ -281,7 +282,7 @@ private static boolean shouldCompareValueWithLowerBound(QuantifiedComparisonExpr private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) { Assignments assignments = Assignments.builder() - .putIdentities(input.getOutputVariables()) + .putAll(identitiesAsSymbolReferences(input.getOutputVariables())) .putAll(subqueryAssignments) .build(); return new ProjectNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java new file mode 100644 index 0000000000000..281062e5f7bc1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; + +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonMap; + +public class AssignmentUtils +{ + private AssignmentUtils() {} + + @Deprecated + public static Map.Entry identityAsSymbolReference(VariableReferenceExpression variable) + { + return singletonMap(variable, asSymbolReference(variable)) + .entrySet().iterator().next(); + } + + @Deprecated + public static Map identitiesAsSymbolReferences(Collection variables) + { + Map map = new LinkedHashMap<>(); + for (VariableReferenceExpression variable : variables) { + map.put(variable, asSymbolReference(variable)); + } + return map; + } + + @Deprecated + public static Assignments identityAssignmentsAsSymbolReferences(Collection variables) + { + return Assignments.builder().putAll(identitiesAsSymbolReferences(variables)).build(); + } + + public static boolean isIdentity(Assignments assignments, VariableReferenceExpression output) + { + //TODO this will be checking against VariableExpression once getOutput returns VariableReferenceExpression + Expression expression = assignments.get(output); + return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); + } + + @Deprecated + public static Assignments identityAssignmentsAsSymbolReferences(VariableReferenceExpression... variables) + { + return identityAssignmentsAsSymbolReferences(asList(variables)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 6c28e4c77e958..799762cc0e990 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -18,7 +18,6 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.sql.tree.SymbolReference; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Predicate; @@ -39,7 +38,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; public class Assignments @@ -49,16 +47,6 @@ public static Builder builder() return new Builder(); } - public static Assignments identity(VariableReferenceExpression... variables) - { - return identity(asList(variables)); - } - - public static Assignments identity(Iterable variables) - { - return builder().putIdentities(variables).build(); - } - public static Assignments copyOf(Map assignments) { return builder() @@ -124,13 +112,6 @@ public Assignments filter(Predicate predicate) .collect(toAssignments()); } - public boolean isIdentity(VariableReferenceExpression output) - { - Expression expression = assignments.get(output); - - return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); - } - private Collector, Builder, Assignments> toAssignments() { return Collector.of( @@ -254,20 +235,6 @@ public Builder put(Entry assignment) return this; } - public Builder putIdentities(Iterable variables) - { - for (VariableReferenceExpression variable : variables) { - putIdentity(variable); - } - return this; - } - - public Builder putIdentity(VariableReferenceExpression variable) - { - put(variable, new SymbolReference(variable.getName())); - return this; - } - public Assignments build() { return new Assignments(assignments); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java index ecdb2994731a8..2c33990005326 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestSimpleFilterProjectSemiJoinStatsRule.java @@ -19,7 +19,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.Expression; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -29,6 +28,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; public class TestSimpleFilterProjectSemiJoinStatsRule extends BaseStatsCalculatorTest @@ -99,10 +99,10 @@ public void testFilterPositiveSemiJoin(boolean toRowExpression) .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(180) - .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) - .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) - .variableStatsUnknown("c") - .variableStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") @@ -127,9 +127,9 @@ public void testFilterPositiveNarrowingProjectSemiJoin(boolean toRowExpression) if (toRowExpression) { return pb.filter( TRANSLATOR.translateAndOptimize(expression("sjo"), pb.getTypes()), - pb.project(Assignments.identity(semiJoinOutput, a), semiJoinNode)); + pb.project(identityAssignmentsAsSymbolReferences(semiJoinOutput, a), semiJoinNode)); } - return pb.filter(expression("sjo"), pb.project(Assignments.identity(semiJoinOutput, a), semiJoinNode)); + return pb.filter(expression("sjo"), pb.project(identityAssignmentsAsSymbolReferences(semiJoinOutput, a), semiJoinNode)); }) .withSourceStats(LEFT_SOURCE_ID, PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) @@ -141,10 +141,10 @@ public void testFilterPositiveNarrowingProjectSemiJoin(boolean toRowExpression) .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(180) - .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) - .variableStatsUnknown("b") - .variableStatsUnknown("c") - .variableStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedAInC)) + .variableStatsUnknown("b") + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") @@ -161,10 +161,10 @@ public void testFilterPositivePlusExtraConjunctSemiJoin(boolean toRowExpression) .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(144) - .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInC)) - .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) - .variableStatsUnknown("c") - .variableStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInC)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } @Test(dataProvider = "toRowExpression") @@ -181,10 +181,10 @@ public void testFilterNegativeSemiJoin(boolean toRowExpression) .addVariableStatistics(new VariableReferenceExpression("c", BIGINT), cStats) .build()) .check(check -> check.outputRowsCount(720) - .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInCWithExtraFilter)) - .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) - .variableStatsUnknown("c") - .variableStatsUnknown("sjo")); + .variableStats(new VariableReferenceExpression("a", BIGINT), assertion -> assertion.isEqualTo(expectedANotInCWithExtraFilter)) + .variableStats(new VariableReferenceExpression("b", BIGINT), assertion -> assertion.isEqualTo(bStats)) + .variableStatsUnknown("c") + .variableStatsUnknown("sjo")); } private StatsCalculatorAssertion getStatsCalculatorAssertion(Expression expression, boolean toRowExpression) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index e9a1f884bab0d..ced3a86ac8905 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -22,7 +22,6 @@ import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; -import com.facebook.presto.sql.planner.optimizations.CheckSubqueryNodesAreRewritten; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java index b649e6d47c27b..96c2f40c3967a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/TestIterativeOptimizer.java @@ -21,7 +21,6 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.tpch.TpchConnectorFactory; @@ -33,6 +32,7 @@ import org.testng.annotations.Test; import static com.facebook.presto.spi.StandardErrorCode.OPTIMIZER_TIMEOUT; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; @@ -107,7 +107,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) if (isIdentityProjection(project)) { return Result.ofPlanNode(project.getSource()); } - PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, Assignments.identity(project.getOutputVariables())); + PlanNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), project, identityAssignmentsAsSymbolReferences(project.getOutputVariables())); return Result.ofPlanNode(projectNode); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java index 2f8cde762a7f9..c7d339b282373 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.java @@ -17,7 +17,6 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.tree.FunctionCall; import com.google.common.collect.ImmutableList; @@ -39,6 +38,7 @@ import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.INTERMEDIATE; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; @@ -316,7 +316,7 @@ public void testInterimProject() p.gatheringExchange( ExchangeNode.Scope.REMOTE_STREAMING, p.project( - Assignments.identity(p.variable("b")), + identityAssignmentsAsSymbolReferences(p.variable("b")), p.aggregation(ap -> ap.globalGrouping() .step(AggregationNode.Step.PARTIAL) .addAggregation(p.variable(p.symbol("b")), expression("count(a)"), ImmutableList.of(BIGINT)) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index 10cf25e2c433b..f31e91a1d18f8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -23,6 +23,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; public class TestInlineProjections extends BaseRuleTest @@ -75,7 +76,7 @@ public void testIdentityProjections() p.project( Assignments.of(p.variable("output"), expression("value")), p.project( - Assignments.identity(p.variable("value")), + identityAssignmentsAsSymbolReferences(p.variable("value")), p.values(p.variable("value"))))) .doesNotFire(); } @@ -86,9 +87,9 @@ public void testSubqueryProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.identity(p.variable("fromOuterScope"), p.variable("value")), + identityAssignmentsAsSymbolReferences(p.variable("fromOuterScope"), p.variable("value")), p.project( - Assignments.identity(p.variable("value")), + identityAssignmentsAsSymbolReferences(p.variable("value")), p.values(p.variable("value"))))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 72b0e87ee33e8..f208afb3696c8 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -39,6 +39,8 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; @@ -181,10 +183,10 @@ public void testIntermediateProjectNodes() p.project( Assignments.builder() .put(p.variable("one"), expression("CAST(1 AS bigint)")) - .putIdentities(ImmutableList.of(p.variable("a"), p.variable("avgOutput"))) + .putAll(identitiesAsSymbolReferences(ImmutableList.of(p.variable("a"), p.variable("avgOutput")))) .build(), p.project( - Assignments.identity(p.variable("a"), p.variable("avgOutput"), p.variable("unused")), + identityAssignmentsAsSymbolReferences(p.variable("a"), p.variable("avgOutput"), p.variable("unused")), p.window( newWindowNodeSpecification(p, "a"), ImmutableMap.of(p.variable(p.symbol("avgOutput")), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, "a")), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java index cddb388e5447f..757143997c1a3 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneAggregationColumns.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -32,6 +31,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -71,7 +71,7 @@ private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate VariableReferenceExpression b = planBuilder.variable("b"); VariableReferenceExpression key = planBuilder.variable("key"); return planBuilder.project( - Assignments.identity(ImmutableList.of(a, b).stream().filter(projectionFilter).collect(toImmutableSet())), + identityAssignmentsAsSymbolReferences(ImmutableList.of(a, b).stream().filter(projectionFilter).collect(toImmutableSet())), planBuilder.aggregation(aggregationBuilder -> aggregationBuilder .source(planBuilder.values(key)) .singleGroupingSet(key) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java index 3ef684df44855..4624f6495c519 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCrossJoinColumns.java @@ -18,7 +18,6 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; @@ -32,6 +31,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.google.common.collect.ImmutableList.toImmutableList; public class TestPruneCrossJoinColumns @@ -89,7 +89,7 @@ private static PlanNode buildProjectedCrossJoin(PlanBuilder p, Predicate outputs = ImmutableList.of(leftValue, rightValue); return p.project( - Assignments.identity( + identityAssignmentsAsSymbolReferences( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java index d59d696198f69..449122de058c1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneFilterColumns.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -28,6 +27,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -70,7 +70,7 @@ private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate 5"), planBuilder.values(a, b))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java index 8035bcc3a2e8c..0525523dad39e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneIndexSourceColumns.java @@ -22,7 +22,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.testing.TestingTransactionHandle; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; @@ -41,6 +40,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.constrainedIndexSource; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -80,7 +80,7 @@ private static PlanNode buildProjectedIndexSource(PlanBuilder p, Predicate outputs = ImmutableList.of(leftKey, leftValue, rightKey, rightValue); return p.project( - Assignments.identity( + identityAssignmentsAsSymbolReferences( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java index 2f41ac6646332..29de47e139380 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneLimitColumns.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -28,6 +27,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.google.common.base.Predicates.alwaysTrue; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -62,7 +62,7 @@ private ProjectNode buildProjectedLimit(PlanBuilder planBuilder, Predicate outputs = ImmutableList.of(match, leftKey, leftKeyHash, leftValue); return p.project( - Assignments.identity( + identityAssignmentsAsSymbolReferences( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java index f994bda05f389..861be3b10bc78 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTopNColumns.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; @@ -30,6 +29,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.topN; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -76,7 +76,7 @@ private ProjectNode buildProjectedTopN(PlanBuilder planBuilder, Predicate filteredInputs = inputs.stream().filter(sourceFilter).collect(toImmutableList()); return p.project( - Assignments.identity( + identityAssignmentsAsSymbolReferences( outputs.stream() .filter(projectionFilter) .collect(toImmutableList())), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index 2364a63c19789..c913a8d21e683 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -38,6 +38,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; @@ -202,7 +203,7 @@ public void testDoesNotFireWhenNotDistinct() p.join( JoinNode.Type.LEFT, p.project(Assignments.builder() - .putIdentity(p.variable("COL1", BIGINT)) + .put(identityAsSymbolReference(p.variable("COL1", BIGINT))) .build(), p.aggregation(builder -> builder.singleGroupingSet(p.variable("COL1"), p.variable("unused")) From 0463baa797ad9cb6c5ee98c3f952e751ab3f4875 Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:48 -0700 Subject: [PATCH 09/11] Replace Expression in Assignments to RowExpression --- .../presto/cost/ProjectStatsRule.java | 14 ++- .../planner/EffectivePredicateExtractor.java | 3 +- .../sql/planner/ExpressionExtractor.java | 4 +- .../sql/planner/LocalExecutionPlanner.java | 47 +++------- .../presto/sql/planner/LogicalPlanner.java | 7 +- .../presto/sql/planner/PlanBuilder.java | 5 +- .../presto/sql/planner/PlanFragmenter.java | 22 ++--- .../presto/sql/planner/QueryPlanner.java | 20 +++-- .../presto/sql/planner/RelationPlanner.java | 23 ++--- .../presto/sql/planner/SubqueryPlanner.java | 10 +-- .../presto/sql/planner/SymbolsExtractor.java | 20 +++++ .../iterative/rule/EliminateCrossJoins.java | 4 +- .../rule/ExpressionRewriteRuleSet.java | 5 +- .../iterative/rule/ExtractSpatialJoins.java | 4 +- .../iterative/rule/GatherAndMergeWindows.java | 4 +- .../rule/ImplementFilteredAggregations.java | 2 +- .../iterative/rule/InlineProjections.java | 34 +++++--- .../rule/ProjectOffPushDownRule.java | 7 +- .../rule/PushAggregationThroughOuterJoin.java | 7 +- ...PushPartialAggregationThroughExchange.java | 2 +- .../rule/PushProjectionThroughExchange.java | 20 +++-- .../rule/PushProjectionThroughUnion.java | 11 ++- ...RewriteSpatialPartitioningAggregation.java | 9 +- .../rule/SimplifyCountOverConstant.java | 2 +- .../TransformCorrelatedInPredicateToJoin.java | 6 +- .../TransformExistsApplyToLateralNode.java | 10 ++- ...rrelatedInPredicateSubqueryToSemiJoin.java | 3 +- .../planner/optimizations/AddExchanges.java | 7 +- .../planner/optimizations/ApplyNodeUtil.java | 3 +- .../HashGenerationOptimizer.java | 18 ++-- .../ImplementIntersectAndExceptAsUnion.java | 4 +- .../optimizations/IndexJoinOptimizer.java | 44 ++++++++-- .../optimizations/MetadataQueryOptimizer.java | 4 +- .../OptimizeMixedDistinctAggregations.java | 10 +-- .../optimizations/PredicatePushDown.java | 36 ++++---- .../optimizations/PropertyDerivations.java | 65 +++++++++++--- .../PruneUnreferencedOutputs.java | 9 +- .../optimizations/PushdownSubfields.java | 5 +- .../ScalarAggregationToJoinRewriter.java | 2 +- .../StreamPropertyDerivations.java | 28 +++++- ...uantifiedComparisonApplyToLateralJoin.java | 5 +- .../optimizations/TranslateExpressions.java | 86 +++++++++++++++++++ .../UnaliasSymbolReferences.java | 6 +- .../optimizations/joins/JoinGraph.java | 4 +- .../presto/sql/planner/plan/ApplyNode.java | 3 +- .../sql/planner/plan/AssignmentUtils.java | 44 ++++++++-- .../presto/sql/planner/plan/Assignments.java | 67 ++++++--------- .../sql/planner/planPrinter/PlanPrinter.java | 6 +- .../sql/planner/sanity/TypeValidator.java | 25 ++++-- .../sanity/ValidateDependenciesChecker.java | 6 +- .../relational/OriginalExpressionUtils.java | 2 +- .../sql/relational/ProjectNodeUtils.java | 22 ++++- .../facebook/presto/util/GraphvizPrinter.java | 8 +- .../presto/cost/TestCostCalculator.java | 4 +- .../TestEffectivePredicateExtractor.java | 4 +- .../sql/planner/TestLogicalPlanner.java | 6 +- .../presto/sql/planner/TestTypeValidator.java | 13 +-- .../planner/assertions/ExpressionMatcher.java | 34 +++++--- .../assertions/PlanMatchingVisitor.java | 5 +- .../sql/planner/assertions/SymbolAliases.java | 16 +++- .../rule/TestEliminateCrossJoins.java | 4 +- .../rule/TestExpressionRewriteRuleSet.java | 10 +-- .../iterative/rule/TestInlineProjections.java | 26 +++--- .../rule/TestMergeAdjacentWindows.java | 3 +- .../TestPruneCountAggregationOverScalar.java | 8 +- .../rule/TestPruneMarkDistinctColumns.java | 6 +- .../rule/TestPruneTableScanColumns.java | 6 +- .../rule/TestPruneValuesColumns.java | 6 +- .../TestPushAggregationThroughOuterJoin.java | 7 +- .../rule/TestPushLimitThroughProject.java | 8 +- .../TestPushProjectionThroughExchange.java | 23 ++--- .../rule/TestPushProjectionThroughUnion.java | 8 +- ...estRemoveUnreferencedScalarApplyNodes.java | 3 +- ...formCorrelatedScalarAggregationToJoin.java | 6 +- ...TestTransformCorrelatedScalarSubquery.java | 8 +- ...mCorrelatedSingleRowSubqueryToProject.java | 4 +- ...TestTransformExistsApplyToLateralJoin.java | 7 +- ...rrelatedInPredicateSubqueryToSemiJoin.java | 5 +- .../iterative/rule/test/PlanBuilder.java | 18 +++- .../iterative/rule/test/TestRuleTester.java | 4 +- .../sql/planner/plan/TestAssingments.java | 3 +- 81 files changed, 694 insertions(+), 385 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java index afc3824374b52..ea2406ef7a87f 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ProjectStatsRule.java @@ -15,16 +15,18 @@ import com.facebook.presto.Session; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.tree.Expression; import java.util.Map; import java.util.Optional; import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static java.util.Objects.requireNonNull; public class ProjectStatsRule @@ -53,8 +55,14 @@ protected Optional doCalculate(ProjectNode node, StatsPro PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() .setOutputRowCount(sourceStats.getOutputRowCount()); - for (Map.Entry entry : node.getAssignments().entrySet()) { - calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(entry.getValue(), sourceStats, session, types)); + for (Map.Entry entry : node.getAssignments().entrySet()) { + RowExpression expression = entry.getValue(); + if (isExpression(expression)) { + calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(castToExpression(expression), sourceStats, session, types)); + } + else { + calculatedStats.addVariableStatistics(entry.getKey(), scalarStatsCalculator.calculate(expression, sourceStats, session)); + } } return Optional.of(calculatedStats.build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java index 19f4f92384011..ab9c09e713373 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java @@ -62,6 +62,7 @@ import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Predicates.in; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.util.Objects.requireNonNull; /** @@ -164,7 +165,7 @@ public Expression visitProject(ProjectNode node, Void context) Expression underlyingPredicate = node.getSource().accept(this, context); - List projectionEqualities = node.getAssignments().entrySet().stream() + List projectionEqualities = transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression).entrySet().stream() .filter(VARIABLE_MATCHES_EXPRESSION.negate()) .map(VARIABLE_ENTRY_TO_EQUALITY) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java index 2ded4fe031662..89ebaf6bf1418 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionExtractor.java @@ -23,7 +23,6 @@ import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.ValuesNode; -import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import java.util.List; @@ -112,7 +111,7 @@ public Void visitFilter(FilterNode node, ImmutableList.Builder co @Override public Void visitProject(ProjectNode node, ImmutableList.Builder context) { - context.addAll(node.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())); + context.addAll(node.getAssignments().getExpressions().stream().collect(toImmutableList())); return super.visitProject(node, context); } @@ -136,7 +135,6 @@ public Void visitApply(ApplyNode node, ImmutableList.Builder cont context.addAll(node.getSubqueryAssignments() .getExpressions() .stream() - .map(OriginalExpressionUtils::castToRowExpression) .collect(toImmutableList())); return super.visitApply(node, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index ab03e4d4aca71..285329db0b76f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -221,7 +221,6 @@ import static com.facebook.presto.SystemSessionProperties.getTaskWriterCount; import static com.facebook.presto.SystemSessionProperties.isExchangeCompressionEnabled; import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; -import static com.facebook.presto.execution.warnings.WarningCollector.NOOP; import static com.facebook.presto.operator.DistinctLimitOperator.DistinctLimitOperatorFactory; import static com.facebook.presto.operator.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory; import static com.facebook.presto.operator.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory; @@ -240,7 +239,6 @@ import static com.facebook.presto.spi.relation.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; -import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; @@ -250,7 +248,7 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.PARTIAL; -import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; @@ -272,11 +270,9 @@ import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.concat; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; import static io.airlift.units.DataSize.Unit.BYTE; -import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1141,7 +1137,7 @@ public PhysicalOperation visitFilter(FilterNode node, LocalExecutionPlanContext RowExpression filterExpression = node.getPredicate(); List outputVariables = node.getOutputVariables(); - return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identityAssignmentsAsSymbolReferences(outputVariables), outputVariables); + return visitScanFilterAndProject(context, node.getId(), sourceNode, Optional.of(filterExpression), identityAssignments(outputVariables), outputVariables); } @Override @@ -1212,30 +1208,15 @@ private PhysicalOperation visitScanFilterAndProject( Map outputMappings = outputMappingsBuilder.build(); // compiler uses inputs instead of symbols, so rewrite the expressions first - - List projections = new ArrayList<>(); - for (VariableReferenceExpression variable : outputVariables) { - projections.add(assignments.get(variable)); - } - - Map, Type> expressionTypes = getExpressionTypes( - context.getSession(), - metadata, - sqlParser, - context.getTypes(), - concat(assignments.getExpressions()), - emptyList(), - NOOP, - false); - - List translatedProjections = projections.stream() - .map(expression -> toRowExpression(expression, expressionTypes, sourceLayout)) + List projections = outputVariables.stream() + .map(assignments::get) + .map(expression -> bindChannels(expression, sourceLayout)) .collect(toImmutableList()); try { if (columns != null) { - Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, translatedProjections, sourceNode.getId()); - Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); + Supplier cursorProcessor = expressionCompiler.compileCursorProcessor(filterExpression, projections, sourceNode.getId()); + Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId)); SourceOperatorFactory operatorFactory = new ScanFilterAndProjectOperatorFactory( context.getNextOperatorId(), @@ -1245,20 +1226,20 @@ private PhysicalOperation visitScanFilterAndProject( cursorProcessor, pageProcessor, columns, - getTypes(projections, expressionTypes), + projections.stream().map(RowExpression::getType).collect(toImmutableList()), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); return new PhysicalOperation(operatorFactory, outputMappings, context, stageExecutionDescriptor.isScanGroupedExecution(sourceNode.getId()) ? GROUPED_EXECUTION : UNGROUPED_EXECUTION); } else { - Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, translatedProjections, Optional.of(context.getStageId() + "_" + planNodeId)); + Supplier pageProcessor = expressionCompiler.compilePageProcessor(filterExpression, projections, Optional.of(context.getStageId() + "_" + planNodeId)); OperatorFactory operatorFactory = new FilterAndProjectOperator.FilterAndProjectOperatorFactory( context.getNextOperatorId(), planNodeId, pageProcessor, - getTypes(projections, expressionTypes), + projections.stream().map(RowExpression::getType).collect(toImmutableList()), getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session)); @@ -2730,14 +2711,6 @@ private OperatorFactory createHashAggregationOperatorFactory( } } - private static List getTypes(List expressions, Map, Type> expressionTypes) - { - return expressions.stream() - .map(NodeRef::of) - .map(expressionTypes::get) - .collect(toImmutableList()); - } - private static TableFinisher createTableFinisher(Session session, TableFinishNode node, Metadata metadata) { WriterTarget target = node.getTarget(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 0697a997b6289..312d2660e8630 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -99,6 +99,7 @@ import static com.facebook.presto.sql.planner.plan.TableWriterNode.WriterTarget; import static com.facebook.presto.sql.planner.sanity.PlanSanityChecker.DISTRIBUTED_PLAN_SANITY_CHECKER; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -358,7 +359,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) int index = insert.getColumns().indexOf(columns.get(column.getName())); if (index < 0) { Expression cast = new Cast(new NullLiteral(), column.getType().getTypeSignature().toString()); - assignments.put(output, cast); + assignments.put(output, castToRowExpression(cast)); } else { Symbol input = plan.getSymbol(index); @@ -366,11 +367,11 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) Type queryType = symbolAllocator.getTypes().get(input); if (queryType.equals(tableType) || metadata.getTypeManager().isTypeOnlyCoercion(queryType, tableType)) { - assignments.put(output, input.toSymbolReference()); + assignments.put(output, castToRowExpression(input.toSymbolReference())); } else { Expression cast = new Cast(input.toSymbolReference(), tableType.getTypeSignature().toString()); - assignments.put(output, cast); + assignments.put(output, castToRowExpression(cast)); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index 991545b032614..f2e5696e478fd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static java.util.Objects.requireNonNull; class PlanBuilder @@ -105,13 +106,13 @@ public PlanBuilder appendProjections(Iterable expressions, SymbolAll // add an identity projection for underlying plan for (VariableReferenceExpression variable : getRoot().getOutputVariables()) { - projections.put(variable, new SymbolReference(variable.getName())); + projections.put(variable, castToRowExpression(new SymbolReference(variable.getName()))); } ImmutableMap.Builder newTranslations = ImmutableMap.builder(); for (Expression expression : expressions) { VariableReferenceExpression variable = symbolAllocator.newVariable(expression, getAnalysis().getTypeWithCoercions(expression)); - projections.put(variable, translations.rewrite(expression)); + projections.put(variable, castToRowExpression(translations.rewrite(expression))); newTranslations.put(variable, expression); } // Now append the new translations into the TranslationMap diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index edc36ab73328f..c6b7e268eb7a0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; @@ -69,7 +70,6 @@ import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; import com.facebook.presto.sql.planner.sanity.PlanSanityChecker; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -102,7 +102,6 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_MATERIALIZED; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; @@ -110,6 +109,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonFragmentPlan; +import static com.facebook.presto.sql.relational.Expressions.constant; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -561,13 +561,13 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite private PartitioningVariableAssignments assignPartitioningVariables(Partitioning partitioning) { ImmutableList.Builder variables = ImmutableList.builder(); - ImmutableMap.Builder constants = ImmutableMap.builder(); + ImmutableMap.Builder constants = ImmutableMap.builder(); for (ArgumentBinding argumentBinding : partitioning.getArguments()) { VariableReferenceExpression variable; if (argumentBinding.isConstant()) { ConstantExpression constant = argumentBinding.getConstant(); - Expression expression = literalEncoder.toExpression(constant.getValue(), constant.getType()); - variable = symbolAllocator.newVariable(expression, constant.getType()); + RowExpression expression = constant(constant.getValue(), constant.getType()); + variable = symbolAllocator.newVariable("constant_partition", constant.getType()); constants.put(variable, expression); } else { @@ -633,7 +633,7 @@ private TableFinishNode createTemporaryTableWrite( List outputs, List> inputs, List sources, - Map constantExpressions, + Map constantExpressions, PartitioningMetadata partitioningMetadata) { if (!constantExpressions.isEmpty()) { @@ -657,8 +657,8 @@ private TableFinishNode createTemporaryTableWrite( sources = sources.stream() .map(source -> { Assignments.Builder assignments = Assignments.builder(); - assignments.putAll(identitiesAsSymbolReferences(source.getOutputVariables())); - constantVariables.forEach(variable -> assignments.put(variable, constantExpressions.get(variable))); + source.getOutputVariables().forEach(variable -> assignments.put(variable, new VariableReferenceExpression(variable.getName(), variable.getType()))); + constantVariables.forEach(symbol -> assignments.put(symbol, constantExpressions.get(symbol))); return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); }) .collect(toImmutableList()); @@ -1217,9 +1217,9 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) private static class PartitioningVariableAssignments { private final List variables; - private final Map constants; + private final Map constants; - private PartitioningVariableAssignments(List variables, Map constants) + private PartitioningVariableAssignments(List variables, Map constants) { this.variables = ImmutableList.copyOf(requireNonNull(variables, "variables is null")); this.constants = ImmutableMap.copyOf(requireNonNull(constants, "constants is null")); @@ -1233,7 +1233,7 @@ public List getVariables() return variables; } - public Map getConstants() + public Map getConstants() { return constants; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index d2ff000c99edd..1ff3c4c39608f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.Analysis; @@ -88,6 +89,7 @@ import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -338,13 +340,13 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression for (Expression expression : expressions) { if (expression instanceof SymbolReference) { VariableReferenceExpression variable = symbolAllocator.toVariableReference(Symbol.from(expression)); - projections.put(variable, expression); + projections.put(variable, castToRowExpression(expression)); outputTranslations.put(expression, variable); continue; } VariableReferenceExpression variable = symbolAllocator.newVariable(expression, analysis.getTypeWithCoercions(expression)); - projections.put(variable, subPlan.rewrite(expression)); + projections.put(variable, castToRowExpression(subPlan.rewrite(expression))); outputTranslations.put(expression, variable); } @@ -355,9 +357,9 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression analysis.getParameters()); } - private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) + private Map coerce(Iterable expressions, PlanBuilder subPlan, TranslationMap translations) { - ImmutableMap.Builder projections = ImmutableMap.builder(); + ImmutableMap.Builder projections = ImmutableMap.builder(); for (Expression expression : expressions) { Type type = analysis.getType(expression); @@ -371,7 +373,7 @@ private Map coerce(Iterable assignments.put(key, new SymbolReference(value.getName()))); + groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(asSymbolReference(value)))); ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build()); subPlan = new PlanBuilder(groupingTranslations, project, analysis.getParameters()); @@ -694,7 +696,7 @@ private PlanBuilder handleGroupingOperations(PlanBuilder subPlan, QuerySpecifica false, metadata.getTypeManager().isTypeOnlyCoercion(analysis.getType(groupingOperation), coercion)); } - projections.put(variable, rewritten); + projections.put(variable, castToRowExpression(rewritten)); newTranslations.put(groupingOperation, variable); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 6296f4d6cac38..e24563f4d5fc3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -94,6 +94,7 @@ import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.google.common.base.Preconditions.checkArgument; @@ -185,7 +186,7 @@ protected RelationPlan visitAliasedRelation(AliasedRelation node, Void context) Field field = subPlan.getDescriptor().getFieldByIndex(i); if (!field.isHidden()) { VariableReferenceExpression aliasedColumn = symbolAllocator.newVariable(field); - assignments.put(aliasedColumn, (new Symbol(subPlan.getFieldMappings().get(i).getName())).toSymbolReference()); + assignments.put(aliasedColumn, castToRowExpression(asSymbolReference(subPlan.getFieldMappings().get(i)))); newMappings.add(aliasedColumn); } } @@ -444,21 +445,21 @@ If casts are redundant (due to column type and common type being equal), // compute the coercion for the field on the left to the common supertype of left & right VariableReferenceExpression leftOutput = symbolAllocator.newVariable(identifier, type); int leftField = joinAnalysis.getLeftJoinFields().get(i); - leftCoercions.put(leftOutput, new Cast( + leftCoercions.put(leftOutput, castToRowExpression(new Cast( left.getSymbol(leftField).toSymbolReference(), type.getTypeSignature().toString(), false, - metadata.getTypeManager().isTypeOnlyCoercion(left.getDescriptor().getFieldByIndex(leftField).getType(), type))); + metadata.getTypeManager().isTypeOnlyCoercion(left.getDescriptor().getFieldByIndex(leftField).getType(), type)))); leftJoinColumns.put(identifier, leftOutput); // compute the coercion for the field on the right to the common supertype of left & right VariableReferenceExpression rightOutput = symbolAllocator.newVariable(identifier, type); int rightField = joinAnalysis.getRightJoinFields().get(i); - rightCoercions.put(rightOutput, new Cast( + rightCoercions.put(rightOutput, castToRowExpression(new Cast( right.getSymbol(rightField).toSymbolReference(), type.getTypeSignature().toString(), false, - metadata.getTypeManager().isTypeOnlyCoercion(right.getDescriptor().getFieldByIndex(rightField).getType(), type))); + metadata.getTypeManager().isTypeOnlyCoercion(right.getDescriptor().getFieldByIndex(rightField).getType(), type)))); rightJoinColumns.put(identifier, rightOutput); clauses.add(new JoinNode.EquiJoinClause(leftOutput, rightOutput)); @@ -490,21 +491,21 @@ If casts are redundant (due to column type and common type being equal), for (Identifier column : joinColumns) { VariableReferenceExpression output = symbolAllocator.newVariable(column, analysis.getType(column)); outputs.add(output); - assignments.put(output, new CoalesceExpression( + assignments.put(output, castToRowExpression(new CoalesceExpression( new SymbolReference(leftJoinColumns.get(column).getName()), - new SymbolReference(rightJoinColumns.get(column).getName()))); + new SymbolReference(rightJoinColumns.get(column).getName())))); } for (int field : joinAnalysis.getOtherLeftFields()) { VariableReferenceExpression variable = left.getFieldMappings().get(field); outputs.add(variable); - assignments.put(variable, new SymbolReference(variable.getName())); + assignments.put(variable, castToRowExpression(new SymbolReference(variable.getName()))); } for (int field : joinAnalysis.getOtherRightFields()) { VariableReferenceExpression variable = right.getFieldMappings().get(field); outputs.add(variable); - assignments.put(variable, new SymbolReference(variable.getName())); + assignments.put(variable, castToRowExpression(new SymbolReference(variable.getName()))); } return new RelationPlan( @@ -738,13 +739,13 @@ private RelationPlan addCoercions(RelationPlan plan, Type[] targetColumnTypes) if (!outputType.equals(inputVariable.getType())) { Expression cast = new Cast(inputSymbol.toSymbolReference(), outputType.getTypeSignature().toString()); VariableReferenceExpression outputVariable = symbolAllocator.newVariable(cast, outputType); - assignments.put(outputVariable, cast); + assignments.put(outputVariable, castToRowExpression(cast)); newVariables.add(outputVariable); } else { SymbolReference symbolReference = inputSymbol.toSymbolReference(); VariableReferenceExpression outputVariable = symbolAllocator.newVariable(symbolReference, outputType); - assignments.put(outputVariable, symbolReference); + assignments.put(outputVariable, castToRowExpression(symbolReference)); newVariables.add(outputVariable); } Field oldField = oldDescriptor.getFieldByIndex(i); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java index 815a0116d99db..7101f5a2de0b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SubqueryPlanner.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.AssignmentUtils; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; @@ -201,7 +202,7 @@ private PlanBuilder appendInPredicateApplyNode(PlanBuilder subPlan, InPredicate subPlan.getTranslations().put(inPredicate, inPredicateSubqueryVariable); - return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubqueryVariable, inPredicateSubqueryExpression), correlationAllowed); + return appendApplyNode(subPlan, inPredicate, subqueryPlan.getRoot(), Assignments.of(inPredicateSubqueryVariable, castToRowExpression(inPredicateSubqueryExpression)), correlationAllowed); } private PlanBuilder appendScalarSubqueryApplyNodes(PlanBuilder builder, Set scalarSubqueries, boolean correlationAllowed) @@ -300,7 +301,7 @@ private PlanBuilder appendExistSubqueryApplyNode(PlanBuilder subPlan, ExistsPred subPlan, existsPredicate.getSubquery(), subqueryNode, - Assignments.of(exists, rewrittenExistsPredicate), + Assignments.of(exists, castToRowExpression(rewrittenExistsPredicate)), correlationAllowed); } @@ -398,7 +399,7 @@ private PlanBuilder planQuantifiedApplyNode(PlanBuilder subPlan, QuantifiedCompa subPlan, quantifiedComparison.getSubquery(), subqueryPlan.getRoot(), - Assignments.of(coercedQuantifiedComparisonVariable, coercedQuantifiedComparison), + Assignments.of(coercedQuantifiedComparisonVariable, castToRowExpression(coercedQuantifiedComparison)), correlationAllowed); } @@ -574,8 +575,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) { ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node); - Assignments assignments = rewrittenNode.getAssignments() - .rewrite(expression -> replaceExpression(expression, mapping)); + Assignments assignments = AssignmentUtils.rewrite(rewrittenNode.getAssignments(), expression -> replaceExpression(expression, mapping)); return new ProjectNode(idAllocator.getNextId(), rewrittenNode.getSource(), assignments); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java index 791dc5ce7fe3d..50276ffac4fe5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolsExtractor.java @@ -29,6 +29,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.Collection; import java.util.List; import java.util.Set; @@ -95,6 +96,23 @@ public static Set extractUniqueVariable(Expression return ImmutableSet.copyOf(extractAllVariable(expression, types)); } + public static Set extractUniqueVariable(RowExpression expression, TypeProvider types) + { + //TODO remove this once we removed all expression from all optimization rules. + // This is used in ValidateDependencyChecker + if (isExpression(expression)) { + return extractUniqueVariable(castToExpression(expression), types); + } + return ImmutableSet.copyOf(extractAll(expression)); + } + + public static Set extractUniqueVariable(Collection expressions, TypeProvider types) + { + return expressions.stream() + .flatMap(expression -> extractUniqueVariable(expression, types).stream()) + .collect(toImmutableSet()); + } + // TODO: return Set public static Set extractUnique(RowExpression expression) { @@ -130,6 +148,7 @@ public static List extractAll(Expression expression) new SymbolBuilderVisitor().process(expression, builder); return builder.build(); } + public static List extractAllVariable(Expression expression, TypeProvider types) { ImmutableList.Builder builder = ImmutableList.builder(); @@ -174,6 +193,7 @@ public static Set extractOutputVariables(PlanNode p .flatMap(node -> node.getOutputVariables().stream()) .collect(toImmutableSet()); } + /** * {@param expression} could be an OriginalExpression */ diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 74864452c9791..03a79ea484cef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -27,6 +27,7 @@ import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -48,6 +49,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -203,7 +205,7 @@ public static PlanNode buildJoinTree(List expectedO result = new ProjectNode( idAllocator.getNextId(), result, - Assignments.copyOf(graph.getAssignments().get())); + Assignments.copyOf(transformValues(graph.getAssignments().get(), OriginalExpressionUtils::castToRowExpression))); } // If needed, introduce a projection to constrain the outputs to what was originally expected diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index 61d2c20633e3a..a5a3735a7ea28 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.AssignmentUtils; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -125,7 +126,7 @@ public Pattern getPattern() @Override public Result apply(ProjectNode projectNode, Captures captures, Context context) { - Assignments assignments = projectNode.getAssignments().rewrite(x -> rewriter.rewrite(x, context)); + Assignments assignments = AssignmentUtils.rewrite(projectNode.getAssignments(), x -> rewriter.rewrite(x, context)); if (projectNode.getAssignments().equals(assignments)) { return Result.empty(); } @@ -325,7 +326,7 @@ public Pattern getPattern() @Override public Result apply(ApplyNode applyNode, Captures captures, Context context) { - Assignments subqueryAssignments = applyNode.getSubqueryAssignments().rewrite(x -> rewriter.rewrite(x, context)); + Assignments subqueryAssignments = AssignmentUtils.rewrite(applyNode.getSubqueryAssignments(), x -> rewriter.rewrite(x, context)); if (applyNode.getSubqueryAssignments().equals(subqueryAssignments)) { return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index f205df0b59be6..cde73ea8eca4b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -591,7 +591,7 @@ private static PlanNode addProjection(Context context, PlanNode node, VariableRe projections.put(identityAsSymbolReference(outputVariable)); } - projections.put(variable, expression); + projections.put(variable, castToRowExpression(expression)); return new ProjectNode(context.getIdAllocator().getNextId(), node, projections.build()); } @@ -609,7 +609,7 @@ private static PlanNode addPartitioningNodes(Context context, PlanNode node, Var FunctionCall partitioningFunction = new FunctionCall(QualifiedName.of("spatial_partitions"), partitioningArguments.build()); VariableReferenceExpression partitionsVariable = context.getSymbolAllocator().newVariable(partitioningFunction, new ArrayType(INTEGER)); - projections.put(partitionsVariable, partitioningFunction); + projections.put(partitionsVariable, castToRowExpression(partitioningFunction)); return new UnnestNode( context.getIdAllocator().getNextId(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java index 170f6df503d01..2c85635c3194e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/GatherAndMergeWindows.java @@ -18,13 +18,13 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.PropertyPattern; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -141,7 +141,7 @@ protected static Optional pullWindowNodeAboveProjects( // The only kind of use of the output of the target that we can safely ignore is a simple identity propagation. // The target node, when hoisted above the projections, will provide the symbols directly. - Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( + Map assignmentsWithoutTargetOutputIdentities = Maps.filterKeys( project.getAssignments().getMap(), output -> !(isIdentity(project.getAssignments(), output) && targetOutputs.contains(output))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 136505e7a74a5..666a793b8e6f3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -99,7 +99,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont Expression filter = OriginalExpressionUtils.castToExpression(entry.getValue().getFilter().get()); VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(filter, BOOLEAN); verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern"); - newAssignments.put(variable, filter); + newAssignments.put(variable, castToRowExpression(filter)); mask = Optional.of(variable); maskSymbols.add(new SymbolReference(variable.getName())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java index f96c90b6810f0..35e99858c87cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/InlineProjections.java @@ -16,13 +16,16 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolsExtractor; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.Assignments.Builder; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.TryExpression; @@ -41,6 +44,9 @@ import static com.facebook.presto.sql.planner.plan.AssignmentUtils.isIdentity; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.Expressions.variable; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static java.util.stream.Collectors.toSet; /** @@ -75,11 +81,11 @@ public Result apply(ProjectNode parent, Captures captures, Context context) // inline the expressions Assignments assignments = child.getAssignments().filter(targets::contains); - Map parentAssignments = parent.getAssignments() + Map parentAssignments = parent.getAssignments() .entrySet().stream() .collect(Collectors.toMap( Map.Entry::getKey, - entry -> inlineReferences(entry.getValue(), assignments))); + entry -> castToRowExpression(inlineReferences(castToExpression(entry.getValue()), assignments)))); // Synthesize identity assignments for the inputs of expressions that were inlined // to place in the child projection. @@ -89,11 +95,12 @@ public Result apply(ProjectNode parent, Captures captures, Context context) .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) .map(Map.Entry::getValue) + .map(OriginalExpressionUtils::castToExpression) .flatMap(entry -> SymbolsExtractor.extractAllVariable(entry, context.getSymbolAllocator().getTypes()).stream()) .collect(toSet()); - Assignments.Builder childAssignments = Assignments.builder(); - for (Map.Entry assignment : child.getAssignments().entrySet()) { + Builder childAssignments = Assignments.builder(); + for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { childAssignments.put(assignment); } @@ -115,12 +122,10 @@ public Result apply(ProjectNode parent, Captures captures, Context context) private Expression inlineReferences(Expression expression, Assignments assignments) { Function mapping = symbol -> { - Expression result = assignments.get(symbol); - if (result != null) { - return result; + if (assignments.get(symbol) == null) { + return symbol.toSymbolReference(); } - - return symbol.toSymbolReference(); + return castToExpression(assignments.get(symbol)); }; return inlineSymbols(mapping, expression); @@ -139,21 +144,24 @@ private Sets.SetView extractInliningTargets(Project Set childOutputSet = ImmutableSet.copyOf(child.getOutputVariables()); Map dependencies = parent.getAssignments() - .getExpressions().stream() - .flatMap(expression -> SymbolsExtractor.extractAllVariable(expression, context.getSymbolAllocator().getTypes()).stream()) + .getExpressions() + .stream() + .map(OriginalExpressionUtils::castToExpression) + .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream()) + .map(symbol -> variable(symbol.getName(), context.getSymbolAllocator().getTypes().get(symbol))) .filter(childOutputSet::contains) .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); // find references to simple constants Set constants = dependencies.keySet().stream() - .filter(input -> child.getAssignments().get(input) instanceof Literal) + .filter(input -> castToExpression(child.getAssignments().get(input)) instanceof Literal) .collect(toSet()); // exclude any complex inputs to TRY expressions. Inlining them would potentially // change the semantics of those expressions Set tryArguments = parent.getAssignments() .getExpressions().stream() - .flatMap(expression -> extractTryArguments(expression, context.getSymbolAllocator().getTypes()).stream()) + .flatMap(expression -> extractTryArguments(castToExpression(expression), context.getSymbolAllocator().getTypes()).stream()) .collect(toSet()); Set singletons = dependencies.entrySet().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java index f6dc03b107a05..d47acff4ff1ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ProjectOffPushDownRule.java @@ -22,6 +22,7 @@ import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import java.util.Optional; @@ -31,6 +32,7 @@ import static com.facebook.presto.sql.planner.iterative.rule.Util.pruneInputs; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.google.common.collect.ImmutableList.toImmutableList; /** * @param The node type to look for under the ProjectNode @@ -61,7 +63,10 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { N targetNode = captures.get(targetCapture); - return pruneInputs(targetNode.getOutputVariables(), parent.getAssignments().getExpressions(), context.getSymbolAllocator().getTypes()) + return pruneInputs( + targetNode.getOutputVariables(), + parent.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toImmutableList()), + context.getSymbolAllocator().getTypes()) .flatMap(prunedOutputs -> this.pushDownProjectOff(context.getIdAllocator(), context.getSymbolAllocator(), targetNode, prunedOutputs)) .map(newChild -> parent.replaceChildren(ImmutableList.of(newChild))) .map(Result::ofPlanNode) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 6253ea71c6763..181fcc1e75007 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -269,10 +269,13 @@ private Optional coalesceWithNullAggregation(AggregationNode aggregati Assignments.Builder assignmentsBuilder = Assignments.builder(); for (VariableReferenceExpression variable : outerJoin.getOutputVariables()) { if (aggregationNode.getAggregations().keySet().contains(variable)) { - assignmentsBuilder.put(variable, new CoalesceExpression(new SymbolReference(variable.getName()), new SymbolReference(sourceAggregationToOverNullMapping.get(variable).getName()))); + assignmentsBuilder.put(variable, castToRowExpression( + new CoalesceExpression( + new SymbolReference(variable.getName()), + new SymbolReference(sourceAggregationToOverNullMapping.get(variable).getName())))); } else { - assignmentsBuilder.put(variable, new SymbolReference(variable.getName())); + assignmentsBuilder.put(variable, castToRowExpression(new SymbolReference(variable.getName()))); } } return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index cb4f7011e62fc..cef8d92de02ed 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -170,7 +170,7 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, for (VariableReferenceExpression output : aggregation.getOutputVariables()) { VariableReferenceExpression input = symbolMapper.map(output); - assignments.put(output, new SymbolReference(input.getName())); + assignments.put(output, castToRowExpression(new SymbolReference(input.getName()))); } partials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedPartial, assignments.build())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 438e3129a86ff..81293565de884 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -17,6 +17,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -26,6 +27,7 @@ import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; @@ -43,6 +45,8 @@ import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; /** * Transforms: @@ -101,14 +105,14 @@ public Result apply(ProjectNode project, Captures captures, Context context) .map(outputToInputMap::get) .forEach(nameReference -> { VariableReferenceExpression variable = toVariableReference(Symbol.from(nameReference), types); - projections.put(variable, nameReference); + projections.put(variable, castToRowExpression(nameReference)); inputs.add(variable); }); if (exchange.getPartitioningScheme().getHashColumn().isPresent()) { // Need to retain the hash symbol for the exchange VariableReferenceExpression hashVariable = exchange.getPartitioningScheme().getHashColumn().get(); - projections.put(hashVariable, new SymbolReference(hashVariable.getName())); + projections.put(hashVariable, castToRowExpression(new SymbolReference(hashVariable.getName()))); inputs.add(hashVariable); } @@ -120,16 +124,16 @@ public Result apply(ProjectNode project, Captures captures, Context context) .map(outputToInputMap::get) .forEach(nameReference -> { VariableReferenceExpression variable = toVariableReference(Symbol.from(nameReference), types); - projections.put(variable, nameReference); + projections.put(variable, castToRowExpression(nameReference)); inputs.add(variable); }); } - for (Map.Entry projection : project.getAssignments().entrySet()) { - Expression translatedExpression = inlineVariables(outputToInputMap, projection.getValue(), types); + for (Map.Entry projection : project.getAssignments().entrySet()) { + Expression translatedExpression = inlineVariables(outputToInputMap, castToExpression(projection.getValue()), types); Type type = projection.getKey().getType(); VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(translatedExpression, type); - projections.put(variable, translatedExpression); + projections.put(variable, castToRowExpression(translatedExpression)); inputs.add(variable); } newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build())); @@ -145,7 +149,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) .filter(variable -> !partitioningColumns.contains(variable)) .forEach(outputBuilder::add); } - for (Map.Entry projection : project.getAssignments().entrySet()) { + for (Map.Entry projection : project.getAssignments().entrySet()) { outputBuilder.add(projection.getKey()); } @@ -172,7 +176,7 @@ public Result apply(ProjectNode project, Captures captures, Context context) private static boolean isSymbolToSymbolProjection(ProjectNode project) { - return project.getAssignments().getExpressions().stream().allMatch(e -> e instanceof SymbolReference); + return project.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).allMatch(e -> e instanceof SymbolReference); } private static Map extractExchangeOutputToInput(ExchangeNode exchange, int sourceIndex) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java index 1c246676a3845..0bae330153fbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughUnion.java @@ -17,8 +17,10 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.ExpressionVariableInliner; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.Assignments; @@ -35,10 +37,11 @@ import java.util.Map; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.sql.planner.ExpressionVariableInliner.inlineVariables; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.union; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; public class PushProjectionThroughUnion implements Rule @@ -76,11 +79,11 @@ public Result apply(ProjectNode parent, Captures captures, Context context) Map projectVariableMapping = new HashMap<>(); // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode - for (Map.Entry entry : parent.getAssignments().entrySet()) { - Expression translatedExpression = inlineVariables(outputToInput, entry.getValue(), context.getSymbolAllocator().getTypes()); + for (Map.Entry entry : parent.getAssignments().entrySet()) { + Expression translatedExpression = ExpressionVariableInliner.inlineVariables(outputToInput, castToExpression(entry.getValue()), context.getSymbolAllocator().getTypes()); Type type = entry.getKey().getType(); VariableReferenceExpression variable = context.getSymbolAllocator().newVariable(translatedExpression, type); - assignments.put(variable, translatedExpression); + assignments.put(variable, castToRowExpression(translatedExpression)); projectVariableMapping.put(new VariableReferenceExpression(entry.getKey().getName(), type), variable); } outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index b4414bef5c0f2..a12e6a4d2423f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -26,7 +26,6 @@ import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.QualifiedName; @@ -95,7 +94,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) { ImmutableMap.Builder aggregations = ImmutableMap.builder(); VariableReferenceExpression partitionCountVariable = context.getSymbolAllocator().newVariable("partition_count", INTEGER); - ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); + ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder(); for (Map.Entry entry : node.getAggregations().entrySet()) { Aggregation aggregation = entry.getValue(); String name = metadata.getFunctionManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName(); @@ -104,10 +103,10 @@ public Result apply(AggregationNode node, Captures captures, Context context) RowExpression geometry = getOnlyElement(aggregation.getArguments()); VariableReferenceExpression envelopeVariable = context.getSymbolAllocator().newVariable("envelope", geometryType); if (isFunctionNameMatch(geometry, "ST_Envelope")) { - envelopeAssignments.put(envelopeVariable, castToExpression(geometry)); + envelopeAssignments.put(envelopeVariable, geometry); } else { - envelopeAssignments.put(envelopeVariable, new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(castToExpression(geometry)))); + envelopeAssignments.put(envelopeVariable, castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(castToExpression(geometry))))); } aggregations.put(entry.getKey(), new Aggregation( @@ -136,7 +135,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getSource(), Assignments.builder() .putAll(identitiesAsSymbolReferences(node.getSource().getOutputVariables())) - .put(partitionCountVariable, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))) + .put(partitionCountVariable, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession()))))) .putAll(envelopeAssignments.build()) .build()), aggregations.build(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 117def48bd12f..4f18025704730 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -119,7 +119,7 @@ private boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Ass RowExpression argument = aggregation.getArguments().get(0); Expression assigned = null; if (castToExpression(argument) instanceof SymbolReference) { - assigned = inputs.get(Symbol.from(castToExpression(argument))); + assigned = castToExpression(inputs.get(Symbol.from(castToExpression(argument)))); } return assigned instanceof Literal && !(assigned instanceof NullLiteral); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 3a79808864f11..cc33cc05737ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -124,7 +124,7 @@ public Result apply(ApplyNode apply, Captures captures, Context context) if (subqueryAssignments.size() != 1) { return Result.empty(); } - Expression assignmentExpression = getOnlyElement(subqueryAssignments.getExpressions()); + Expression assignmentExpression = castToExpression(getOnlyElement(subqueryAssignments.getExpressions())); if (!(assignmentExpression instanceof InPredicate)) { return Result.empty(); } @@ -183,7 +183,7 @@ private PlanNode buildInPredicateEquivalent( decorrelatedBuildSource, Assignments.builder() .putAll(identitiesAsSymbolReferences(decorrelatedBuildSource.getOutputVariables())) - .put(buildSideKnownNonNull, bigint(0)) + .put(buildSideKnownNonNull, castToRowExpression(bigint(0))) .build()); SymbolReference probeSideSymbolReference = Symbol.from(inPredicate.getValue()).toSymbolReference(); @@ -233,7 +233,7 @@ private PlanNode buildInPredicateEquivalent( aggregation, Assignments.builder() .putAll(identitiesAsSymbolReferences(apply.getInput().getOutputVariables())) - .put(inPredicateOutputVariable, inPredicateEquivalent) + .put(inPredicateOutputVariable, castToRowExpression(inPredicateEquivalent)) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 0c5ed7bb4c7ee..722e7ad370f28 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -51,6 +51,8 @@ import static com.facebook.presto.sql.planner.plan.LateralJoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; import static com.google.common.base.Preconditions.checkState; @@ -103,7 +105,7 @@ public Result apply(ApplyNode parent, Captures captures, Context context) return Result.empty(); } - Expression expression = getOnlyElement(parent.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(parent.getSubqueryAssignments().getExpressions())); if (!(expression instanceof ExistsPredicate)) { return Result.empty(); } @@ -123,7 +125,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C Assignments.Builder assignments = Assignments.builder(); assignments.putAll(identitiesAsSymbolReferences(applyNode.getInput().getOutputVariables())); - assignments.put(exists, new CoalesceExpression(ImmutableList.of(new SymbolReference(subqueryTrue.getName()), BooleanLiteral.FALSE_LITERAL))); + assignments.put(exists, castToRowExpression(new CoalesceExpression(ImmutableList.of(new SymbolReference(subqueryTrue.getName()), BooleanLiteral.FALSE_LITERAL)))); PlanNode subquery = new ProjectNode( context.getIdAllocator().getNextId(), @@ -132,7 +134,7 @@ private Optional rewriteToNonDefaultAggregation(ApplyNode applyNode, C applyNode.getSubquery(), 1L, false), - Assignments.of(subqueryTrue, TRUE_LITERAL)); + Assignments.of(subqueryTrue, castToRowExpression(TRUE_LITERAL))); PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getSymbolAllocator(), context.getLookup()); if (!decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isPresent()) { @@ -178,7 +180,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty()), - Assignments.of(exists, new ComparisonExpression(GREATER_THAN, asSymbolReference(count), new Cast(new LongLiteral("0"), BIGINT.toString())))), + Assignments.of(exists, castToRowExpression(new ComparisonExpression(GREATER_THAN, asSymbolReference(count), new Cast(new LongLiteral("0"), BIGINT.toString()))))), parent.getCorrelation(), INNER, parent.getOriginSubqueryError()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java index 362dbc7a2e14f..21dab7342aa50 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -28,6 +28,7 @@ import static com.facebook.presto.matching.Pattern.empty; import static com.facebook.presto.sql.planner.plan.Patterns.Apply.correlation; import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.google.common.collect.Iterables.getOnlyElement; /** @@ -72,7 +73,7 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) return Result.empty(); } - Expression expression = getOnlyElement(applyNode.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(applyNode.getSubqueryAssignments().getExpressions())); if (!(expression instanceof InPredicate)) { return Result.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index bc72830d47bad..f683b8961cdca 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; @@ -1375,9 +1376,9 @@ private boolean canPushdownPartialMergeThroughLowMemoryOperators(PlanNode node) public static Map computeIdentityTranslations(Assignments assignments, TypeProvider types) { Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), toVariableReference(Symbol.from(assignment.getValue()), types)); + for (Map.Entry assignment : assignments.getMap().entrySet()) { + if (castToExpression(assignment.getValue()) instanceof SymbolReference) { + outputToInput.put(assignment.getKey(), toVariableReference(Symbol.from(castToExpression(assignment.getValue())), types)); } } return outputToInput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java index 348871a467fa3..7f914e0b7b46c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyNodeUtil.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.InPredicate; @@ -28,7 +29,7 @@ private ApplyNodeUtil() {} public static void verifySubquerySupported(Assignments assignments) { checkArgument( - assignments.getExpressions().stream().allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), + assignments.getExpressions().stream().map(OriginalExpressionUtils::castToExpression).allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), "Unexpected expression used for subquery expression"); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index ea81e3d7df522..ce2d64ddbb799 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; @@ -78,6 +79,9 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -623,7 +627,7 @@ public PlanWithProperties visitProject(ProjectNode node, HashComputationSet pare else { hashExpression = new SymbolReference(hashVariable.getName()); } - newAssignments.put(hashVariable, hashExpression); + newAssignments.put(hashVariable, castToRowExpression(hashExpression)); allHashVariables.put(hashComputation, hashVariable); } @@ -720,7 +724,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo for (VariableReferenceExpression variable : planWithProperties.getNode().getOutputVariables()) { HashComputation partitionVariables = resultHashVariables.get(variable); if (partitionVariables == null || requiredHashes.getHashes().contains(partitionVariables)) { - assignments.put(variable, new SymbolReference(variable.getName())); + assignments.put(variable, castToRowExpression(asSymbolReference(variable))); if (partitionVariables != null) { outputHashVariables.put(partitionVariables, planWithProperties.getHashVariables().get(partitionVariables)); @@ -733,7 +737,7 @@ private PlanWithProperties enforce(PlanWithProperties planWithProperties, HashCo if (!planWithProperties.getHashVariables().containsKey(hashComputation)) { Expression hashExpression = hashComputation.getHashExpression(); VariableReferenceExpression hashVariable = symbolAllocator.newHashVariable(); - assignments.put(hashVariable, hashExpression); + assignments.put(hashVariable, castToRowExpression(hashExpression)); outputHashVariables.put(hashComputation, hashVariable); } } @@ -962,12 +966,12 @@ public VariableReferenceExpression getRequiredHashVariable(HashComputation hash) } } - private static Map computeIdentityTranslations(Map assignments) + private static Map computeIdentityTranslations(Map assignments) { Map outputToInput = new HashMap<>(); - for (Map.Entry assignment : assignments.entrySet()) { - if (assignment.getValue() instanceof SymbolReference) { - outputToInput.put(assignment.getKey(), new VariableReferenceExpression(((SymbolReference) assignment.getValue()).getName(), assignment.getKey().getType())); + for (Map.Entry assignment : assignments.entrySet()) { + if (castToExpression(assignment.getValue()) instanceof SymbolReference) { + outputToInput.put(assignment.getKey(), new VariableReferenceExpression(((SymbolReference) castToExpression(assignment.getValue())).getName(), assignment.getKey().getType())); } } return outputToInput; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index c2447078a22e5..668ff97d2a1f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -220,13 +220,13 @@ private PlanNode appendMarkers(PlanNode source, int markerIndex, List entry : projections.entrySet()) { VariableReferenceExpression variable = symbolAllocator.newVariable(entry.getKey().getName(), entry.getKey().getType()); - assignments.put(variable, entry.getValue()); + assignments.put(variable, castToRowExpression(entry.getValue())); } // add extra marker fields to the projection for (int i = 0; i < markers.size(); ++i) { Expression expression = (i == markerIndex) ? TRUE_LITERAL : new Cast(new NullLiteral(), StandardTypes.BOOLEAN); - assignments.put(symbolAllocator.newVariable(markers.get(i).getName(), BOOLEAN), expression); + assignments.put(symbolAllocator.newVariable(markers.get(i).getName(), BOOLEAN), castToRowExpression(expression)); } return new ProjectNode(idAllocator.getNextId(), source, assignments.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index 15de6d73ff02e..34e536d2797d8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.predicate.TupleDomain; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.ExpressionDomainTranslator; import com.facebook.presto.sql.planner.LiteralEncoder; @@ -60,6 +61,7 @@ import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -172,14 +174,14 @@ else if (leftIndexCandidate.isPresent()) { case LEFT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && rightIndexCandidate.isPresent()) { - return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinVariables, rightJoinVariables), idAllocator, symbolAllocator); + return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), createEquiJoinClause(leftJoinVariables, rightJoinVariables), idAllocator); } break; case RIGHT: // We cannot use indices for outer joins until index join supports in-line filtering if (!node.getFilter().isPresent() && leftIndexCandidate.isPresent()) { - return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinVariables, leftJoinVariables), idAllocator, symbolAllocator); + return createIndexJoinWithExpectedOutputs(node.getOutputVariables(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), createEquiJoinClause(rightJoinVariables, leftJoinVariables), idAllocator); } break; @@ -203,8 +205,7 @@ private static PlanNode createIndexJoinWithExpectedOutputs( PlanNode probe, PlanNode index, List equiJoinClause, - PlanNodeIdAllocator idAllocator, - SymbolAllocator symbolAllocator) + PlanNodeIdAllocator idAllocator) { PlanNode result = new IndexJoinNode(idAllocator.getNextId(), type, probe, index, equiJoinClause, Optional.empty(), Optional.empty()); if (!result.getOutputVariables().equals(expectedOutputs)) { @@ -335,9 +336,9 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) // Rewrite the lookup variables in terms of only the pre-projected variables that have direct translations ImmutableSet.Builder newLookupVariablesBuilder = ImmutableSet.builder(); for (VariableReferenceExpression variable : context.get().getLookupVariables()) { - Expression expression = node.getAssignments().get(variable); - if (expression instanceof SymbolReference) { - newLookupVariablesBuilder.add(new VariableReferenceExpression(((SymbolReference) expression).getName(), variable.getType())); + RowExpression expression = node.getAssignments().get(variable); + if (castToExpression(expression) instanceof SymbolReference) { + newLookupVariablesBuilder.add(new VariableReferenceExpression(((SymbolReference) castToExpression(expression)).getName(), variable.getType())); } } ImmutableSet newLookupVariables = newLookupVariablesBuilder.build(); @@ -496,8 +497,12 @@ public Map visitPlan(P @Override public Map visitProject(ProjectNode node, Set lookupVariables) { - // Map from output variables to source variables - Map directSymbolTranslationOutputMap = Maps.transformValues(Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from); + // Map from output Symbols to source Symbols + Map directSymbolTranslationOutputMap = Maps.transformValues( + Maps.filterValues( + node.getAssignments().getMap(), + IndexKeyTracer::isVariable), + this::extractSymbol); Map outputToSourceMap = lookupVariables.stream() .filter(directSymbolTranslationOutputMap.keySet()::contains) .collect(toImmutableMap(identity(), variable -> new VariableReferenceExpression(directSymbolTranslationOutputMap.get(variable).getName(), variable.getType()))); @@ -560,6 +565,27 @@ public Map visitIndexS checkState(node.getLookupVariables().equals(lookupVariables), "lookupVariables must be the same as IndexSource lookup variables"); return lookupVariables.stream().collect(toImmutableMap(identity(), identity())); } + + private Symbol extractSymbol(RowExpression expression) + { + // TODO remove isExpression once all optimization rule is using RowExpression. + // Handle both expression and rowExpression because ValidateDependenciesChecker used it. + if (expression instanceof VariableReferenceExpression) { + return new Symbol(((VariableReferenceExpression) expression).getName()); + } + checkArgument(isExpression(expression), "Must either be VariableReference or SymbolReference"); + return Symbol.from(castToExpression(expression)); + } + } + + private static boolean isVariable(RowExpression expression) + { + // TODO remove isExpression once all optimization rule is using RowExpression. + // Handle both expression and rowExpression because ValidateDependenciesChecker used it. + if (isExpression(expression)) { + return castToExpression(expression) instanceof SymbolReference; + } + return expression instanceof VariableReferenceExpression; } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index cacaf2aaf83e9..840615bd2caee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -42,6 +42,7 @@ import com.facebook.presto.sql.planner.plan.SortNode; import com.facebook.presto.sql.planner.plan.TopNNode; import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -54,6 +55,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; /** * Converts cardinality-insensitive aggregations (max, min, "distinct") over partition keys @@ -190,7 +192,7 @@ private static Optional findTableScan(PlanNode source) else if (source instanceof ProjectNode) { // verify projections are deterministic ProjectNode project = (ProjectNode) source; - if (!Iterables.all(project.getAssignments().getExpressions(), ExpressionDeterminismEvaluator::isDeterministic)) { + if (!Iterables.all(project.getAssignments().getExpressions().stream().map(OriginalExpressionUtils::castToExpression).collect(toList()), ExpressionDeterminismEvaluator::isDeterministic)) { return Optional.empty(); } source = project.getSource(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 82df34d78c6e0..7329c93e27204 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -223,7 +223,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) // pre-projected symbols. Predicate isSupported = conjunct -> ExpressionDeterminismEvaluator.isDeterministic(conjunct) && - SymbolsExtractor.extractUniqueVariable(conjunct, types).stream() - .allMatch(node.getPartitionBy()::contains); + SymbolsExtractor.extractUniqueVariable(conjunct, types) + .stream() + .allMatch(node.getPartitionBy()::contains); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); @@ -244,7 +246,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) public PlanNode visitProject(ProjectNode node, RewriteContext context) { Set deterministicVariables = node.getAssignments().entrySet().stream() - .filter(entry -> ExpressionDeterminismEvaluator.isDeterministic(entry.getValue())) + .filter(entry -> ExpressionDeterminismEvaluator.isDeterministic(castToExpression(entry.getValue()))) .map(Map.Entry::getKey) .collect(Collectors.toSet()); @@ -262,7 +264,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext contex .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() - .map(entry -> inlineVariables(node.getAssignments().getMap(), entry, types)) + .map(entry -> ExpressionVariableInliner.inlineVariables(Maps.transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression), entry, types)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); @@ -296,7 +298,7 @@ private boolean isInliningCandidate(Expression expression, ProjectNode node) .collect(Collectors.groupingBy(identity(), Collectors.counting())); return dependencies.entrySet().stream() - .allMatch(entry -> entry.getValue() == 1 || node.getAssignments().get(entry.getKey()) instanceof Literal); + .allMatch(entry -> entry.getValue() == 1 || castToExpression(node.getAssignments().get(entry.getKey())) instanceof Literal); } @Override @@ -457,14 +459,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Create identity projections for all existing symbols Assignments.Builder leftProjections = Assignments.builder(); - leftProjections.putAll(node.getLeft() - .getOutputVariables().stream() - .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); + leftProjections.putAll(identityAssignmentsAsSymbolReferences(node.getLeft() + .getOutputVariables())); Assignments.Builder rightProjections = Assignments.builder(); - rightProjections.putAll(node.getRight() - .getOutputVariables().stream() - .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); + rightProjections.putAll(identityAssignmentsAsSymbolReferences(node.getRight() + .getOutputVariables())); // Create new projections for the new join clauses List equiJoinClauses = new ArrayList<>(); @@ -479,12 +479,12 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) VariableReferenceExpression leftVariable = variableForExpression(leftExpression); if (!node.getLeft().getOutputVariables().contains(leftVariable)) { - leftProjections.put(leftVariable, leftExpression); + leftProjections.put(leftVariable, castToRowExpression(leftExpression)); } VariableReferenceExpression rightVariable = variableForExpression(rightExpression); if (!node.getRight().getOutputVariables().contains(rightVariable)) { - rightProjections.put(rightVariable, rightExpression); + rightProjections.put(rightVariable, castToRowExpression(rightExpression)); } equiJoinClauses.add(new JoinNode.EquiJoinClause(leftVariable, rightVariable)); @@ -627,14 +627,12 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext new SymbolReference(variable.getName())))); + leftProjections.putAll(identityAssignmentsAsSymbolReferences(node.getLeft() + .getOutputVariables())); Assignments.Builder rightProjections = Assignments.builder(); - rightProjections.putAll(node.getRight() - .getOutputVariables().stream() - .collect(Collectors.toMap(identity(), variable -> new SymbolReference(variable.getName())))); + rightProjections.putAll(identityAssignmentsAsSymbolReferences(node.getRight() + .getOutputVariables())); leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index f92c589a67593..aa1b6538b4f51 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.parser.SqlParser; @@ -36,6 +37,7 @@ import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.NoOpSymbolResolver; import com.facebook.presto.sql.planner.OrderingScheme; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.ActualProperties.Global; @@ -626,29 +628,45 @@ public ActualProperties visitProject(ProjectNode node, List in // Extract additional constants Map constants = new HashMap<>(); - for (Map.Entry assignment : node.getAssignments().entrySet()) { - Expression expression = assignment.getValue(); + for (Map.Entry assignment : node.getAssignments().entrySet()) { + RowExpression expression = assignment.getValue(); VariableReferenceExpression output = assignment.getKey(); - Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, expression, emptyList(), WarningCollector.NOOP); - Type type = requireNonNull(expressionTypes.get(NodeRef.of(expression))); - ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes); // TODO: // We want to use a symbol resolver that looks up in the constants from the input subplan // to take advantage of constant-folding for complex expressions // However, that currently causes errors when those expressions operate on arrays or row types // ("ROW comparison not supported for fields with null elements", etc) - Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); - - if (value instanceof SymbolReference) { - VariableReferenceExpression variable = toVariableReference(Symbol.from((SymbolReference) value), types); - ConstantExpression existingConstantValue = constants.get(variable); - if (existingConstantValue != null) { - constants.put(output, new ConstantExpression(value, type)); + if (isExpression(expression)) { + Map, Type> expressionTypes = getExpressionTypes(session, metadata, parser, types, castToExpression(expression), emptyList(), WarningCollector.NOOP); + Type type = requireNonNull(expressionTypes.get(NodeRef.of(castToExpression(expression)))); + ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(castToExpression(expression), metadata, session, expressionTypes); + Object value = optimizer.optimize(NoOpSymbolResolver.INSTANCE); + + if (value instanceof SymbolReference) { + Symbol symbol = Symbol.from((SymbolReference) value); + ConstantExpression existingConstantValue = constants.get(symbol); + if (existingConstantValue != null) { + constants.put(assignment.getKey(), new ConstantExpression(value, type)); + } + } + else if (!(value instanceof Expression)) { + constants.put(assignment.getKey(), new ConstantExpression(value, type)); } } - else if (!(value instanceof Expression)) { - constants.put(output, new ConstantExpression(value, type)); + else { + Object value = new RowExpressionInterpreter(expression, metadata, session.toConnectorSession(), true).optimize(); + + if (value instanceof VariableReferenceExpression) { + Symbol symbol = new Symbol(((VariableReferenceExpression) value).getName()); + ConstantExpression existingConstantValue = constants.get(symbol); + if (existingConstantValue != null) { + constants.put(assignment.getKey(), new ConstantExpression(value, ((VariableReferenceExpression) value).getType())); + } + } + else if (!(value instanceof RowExpression)) { + constants.put(assignment.getKey(), new ConstantExpression(value, expression.getType())); + } } } constants.putAll(translatedProperties.getConstants()); @@ -776,6 +794,25 @@ private static Optional> translateToNonConstan return Optional.of(ImmutableList.copyOf(builder.build())); } + + private static Map computeIdentityTranslations(Map assignments, TypeProvider types) + { + Map inputToOutput = new HashMap<>(); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + inputToOutput.put(toVariableReference(Symbol.from(castToExpression(expression)), types), assignment.getKey()); + } + } + else { + if (expression instanceof VariableReferenceExpression) { + inputToOutput.put((VariableReferenceExpression) expression, assignment.getKey()); + } + } + } + return inputToOutput; + } } private static Map computeIdentityTranslations(Map assignments, TypeProvider types) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 9ac34447c8a27..bbf853cf971ec 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -85,6 +85,7 @@ import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -539,7 +540,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext { if (context.get().contains(variable)) { - expectedInputs.addAll(SymbolsExtractor.extractUniqueVariable(expression, symbolAllocator.getTypes())); + expectedInputs.addAll(SymbolsExtractor.extractUniqueVariable(castToExpression(expression), symbolAllocator.getTypes())); builder.put(variable, expression); } }); @@ -795,12 +796,12 @@ public PlanNode visitApply(ApplyNode node, RewriteContext subqueryAssignmentsVariablesBuilder = ImmutableSet.builder(); Assignments.Builder subqueryAssignments = Assignments.builder(); - for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { + for (Map.Entry entry : node.getSubqueryAssignments().getMap().entrySet()) { VariableReferenceExpression output = entry.getKey(); - Expression expression = entry.getValue(); + Expression expression = castToExpression(entry.getValue()); if (context.get().contains(output)) { subqueryAssignmentsVariablesBuilder.addAll(SymbolsExtractor.extractUniqueVariable(expression, symbolAllocator.getTypes())); - subqueryAssignments.put(output, expression); + subqueryAssignments.put(output, castToRowExpression(expression)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java index 5c7904660be23..1b21daa02b0a0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.RowType; @@ -222,9 +223,9 @@ public PlanNode visitOutput(OutputNode node, RewriteContext context) @Override public PlanNode visitProject(ProjectNode node, RewriteContext context) { - for (Map.Entry entry : node.getAssignments().entrySet()) { + for (Map.Entry entry : node.getAssignments().entrySet()) { VariableReferenceExpression variable = entry.getKey(); - Expression expression = entry.getValue(); + Expression expression = castToExpression(entry.getValue()); if (expression instanceof SymbolReference) { context.get().addAssignment(variable, new VariableReferenceExpression(((SymbolReference) expression).getName(), types.get(Symbol.from(expression)))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java index 3faffde8063ea..5943c0467ebb2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ScalarAggregationToJoinRewriter.java @@ -84,7 +84,7 @@ public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, Aggreg VariableReferenceExpression nonNull = symbolAllocator.newVariable("non_null", BooleanType.BOOLEAN); Assignments scalarAggregationSourceAssignments = Assignments.builder() .putAll(identitiesAsSymbolReferences(source.get().getNode().getOutputVariables())) - .put(nonNull, TRUE_LITERAL) + .put(nonNull, castToRowExpression(TRUE_LITERAL)) .build(); ProjectNode scalarAggregationSourceWithNonNullableVariable = new ProjectNode( idAllocator.getNextId(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 71425d02f1bfb..e3e3aa7e57852 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -21,9 +21,11 @@ import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Partitioning.ArgumentBinding; +import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -57,6 +59,7 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -76,11 +79,13 @@ import java.util.stream.Collectors; import static com.facebook.presto.spi.predicate.TupleDomain.extractFixedValuesToConstantExpressions; +import static com.facebook.presto.sql.planner.PlannerUtils.toVariableReference; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; -import static com.facebook.presto.sql.planner.optimizations.AddExchanges.computeIdentityTranslations; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.FIXED; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.MULTIPLE; import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.SINGLE; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -332,11 +337,30 @@ public StreamProperties visitProject(ProjectNode node, List in StreamProperties properties = Iterables.getOnlyElement(inputProperties); // We can describe properties in terms of inputs that are projected unmodified (i.e., identity projections) - Map identities = computeIdentityTranslations(node.getAssignments(), types); + Map identities = computeIdentityTranslations(node.getAssignments().getMap(), types); return properties.translate(column -> Optional.ofNullable(identities.get(column))); } + private static Map computeIdentityTranslations(Map assignments, TypeProvider types) + { + Map inputToOutput = new HashMap<>(); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + inputToOutput.put(toVariableReference(Symbol.from(castToExpression(expression)), types), assignment.getKey()); + } + } + else { + if (expression instanceof VariableReferenceExpression) { + inputToOutput.put((VariableReferenceExpression) expression, assignment.getKey()); + } + } + } + return inputToOutput; + } + @Override public StreamProperties visitGroupId(GroupIdNode node, List inputProperties) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index f4ef3278f8b83..f1716a184d829 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -61,6 +61,7 @@ import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -114,7 +115,7 @@ public PlanNode visitApply(ApplyNode node, RewriteContext context) return context.defaultRewrite(node); } - Expression expression = getOnlyElement(node.getSubqueryAssignments().getExpressions()); + Expression expression = castToExpression(getOnlyElement(node.getSubqueryAssignments().getExpressions())); if (!(expression instanceof QuantifiedComparisonExpression)) { return context.defaultRewrite(node); } @@ -206,7 +207,7 @@ countNonNullValue, new Aggregation( VariableReferenceExpression quantifiedComparisonVariable = getOnlyElement(node.getSubqueryAssignments().getVariables()); - return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonVariable, valueComparedToSubquery)); + return projectExpressions(lateralJoinNode, Assignments.of(quantifiedComparisonVariable, castToRowExpression(valueComparedToSubquery))); } public Expression rewriteUsingBounds(QuantifiedComparisonExpression quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java index ef83cdecb90a6..e4a9b019e81a4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TranslateExpressions.java @@ -30,7 +30,10 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; +import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SpatialJoinNode; import com.facebook.presto.sql.planner.plan.StatisticAggregations; import com.facebook.presto.sql.planner.plan.TableFinishNode; @@ -59,8 +62,10 @@ import static com.facebook.presto.execution.warnings.WarningCollector.NOOP; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.planner.plan.Patterns.applyNode; import static com.facebook.presto.sql.planner.plan.Patterns.filter; import static com.facebook.presto.sql.planner.plan.Patterns.join; +import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.spatialJoin; import static com.facebook.presto.sql.planner.plan.Patterns.tableFinish; import static com.facebook.presto.sql.planner.plan.Patterns.tableWriterNode; @@ -93,6 +98,8 @@ public Set> rules() return ImmutableSet.of( new ValuesExpressionTranslation(), new FilterExpressionTranslation(), + new ProjectExpressionTranslation(), + new ApplyExpressionTranslation(), new WindowExpressionTranslation(), new JoinExpressionTranslation(), new SpatialJoinExpressionTranslation(), @@ -217,6 +224,56 @@ public Result apply(WindowNode windowNode, Captures captures, Context context) } } + private final class ProjectExpressionTranslation + implements Rule + { + @Override + public Pattern getPattern() + { + return project(); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + Assignments assignments = projectNode.getAssignments(); + Optional rewrittenAssignments = translateAssignments(assignments, context); + + if (!rewrittenAssignments.isPresent()) { + return Result.empty(); + } + return Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), rewrittenAssignments.get())); + } + } + + private final class ApplyExpressionTranslation + implements Rule + { + @Override + public Pattern getPattern() + { + return applyNode(); + } + + @Override + public Result apply(ApplyNode applyNode, Captures captures, Context context) + { + Assignments assignments = applyNode.getSubqueryAssignments(); + Optional rewrittenAssignments = translateAssignments(assignments, context); + + if (!rewrittenAssignments.isPresent()) { + return Result.empty(); + } + return Result.ofPlanNode(new ApplyNode( + applyNode.getId(), + applyNode.getInput(), + applyNode.getSubquery(), + rewrittenAssignments.get(), + applyNode.getCorrelation(), + applyNode.getOriginSubqueryError())); + } + } + private final class FilterExpressionTranslation implements Rule { @@ -523,4 +580,33 @@ private RowExpression removeOriginalExpression(RowExpression rowExpression, Sess } return rowExpression; } + + /** + * Return Optional.empty() to denote unchanged assignments + */ + private Optional translateAssignments(Assignments assignments, Rule.Context context) + { + Assignments.Builder builder = Assignments.builder(); + boolean anyRewritten = false; + for (Map.Entry entry : assignments.entrySet()) { + RowExpression expression = entry.getValue(); + RowExpression rewritten; + if (isExpression(expression)) { + rewritten = toRowExpression( + castToExpression(expression), + context.getSession(), + analyze(castToExpression(expression), context.getSession(), context.getSymbolAllocator().getTypes())); + anyRewritten = true; + } + else { + rewritten = expression; + } + builder.put(entry.getKey(), rewritten); + } + if (!anyRewritten) { + return Optional.empty(); + } + + return Optional.of(builder.build()); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index ea03cecd671d1..14fde53e22bd6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -637,8 +637,8 @@ private Assignments canonicalize(Assignments oldAssignments) { Map computedExpressions = new HashMap<>(); Assignments.Builder assignments = Assignments.builder(); - for (Map.Entry entry : oldAssignments.getMap().entrySet()) { - Expression expression = canonicalize(entry.getValue()); + for (Map.Entry entry : oldAssignments.getMap().entrySet()) { + Expression expression = canonicalize(castToExpression(entry.getValue())); if (expression instanceof SymbolReference) { // Always map a trivial symbol projection @@ -663,7 +663,7 @@ else if (ExpressionDeterminismEvaluator.isDeterministic(expression) && !(express } VariableReferenceExpression canonical = canonicalize(entry.getKey()); - assignments.put(canonical, expression); + assignments.put(canonical, castToRowExpression(expression)); } return assignments.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java index ba970a88baf06..0d053d47533c2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/joins/JoinGraph.java @@ -22,6 +22,7 @@ import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; @@ -39,6 +40,7 @@ import static com.facebook.presto.sql.relational.ProjectNodeUtils.isIdentity; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.transformValues; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -281,7 +283,7 @@ public JoinGraph visitProject(ProjectNode node, Context context) { if (isIdentity(node)) { JoinGraph graph = node.getSource().accept(this, context); - return graph.withAssignments(node.getAssignments().getMap()); + return graph.withAssignments(transformValues(node.getAssignments().getMap(), OriginalExpressionUtils::castToExpression)); } return visitPlan(node, context); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index 06d9d4ec4217a..8364bb2457684 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -84,7 +85,7 @@ public ApplyNode( checkArgument(input.getOutputVariables().containsAll(correlation), "Input does not contain symbols from correlation"); checkArgument( - subqueryAssignments.getExpressions().stream().allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), + subqueryAssignments.getExpressions().stream().map(OriginalExpressionUtils::castToExpression).allMatch(ApplyNodeUtil::isSupportedSubqueryExpression), "Unexpected expression used for subquery expression"); this.input = input; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java index 281062e5f7bc1..036e0f7758dd7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AssignmentUtils.java @@ -13,15 +13,21 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.Maps; import java.util.Collection; import java.util.LinkedHashMap; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collector; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static java.util.Arrays.asList; import static java.util.Collections.singletonMap; @@ -30,18 +36,18 @@ public class AssignmentUtils private AssignmentUtils() {} @Deprecated - public static Map.Entry identityAsSymbolReference(VariableReferenceExpression variable) + public static Map.Entry identityAsSymbolReference(VariableReferenceExpression variable) { - return singletonMap(variable, asSymbolReference(variable)) + return singletonMap(variable, castToRowExpression(asSymbolReference(variable))) .entrySet().iterator().next(); } @Deprecated - public static Map identitiesAsSymbolReferences(Collection variables) + public static Map identitiesAsSymbolReferences(Collection variables) { - Map map = new LinkedHashMap<>(); + Map map = new LinkedHashMap<>(); for (VariableReferenceExpression variable : variables) { - map.put(variable, asSymbolReference(variable)); + map.put(variable, castToRowExpression(asSymbolReference(variable))); } return map; } @@ -52,10 +58,17 @@ public static Assignments identityAssignmentsAsSymbolReferences(Collection variables) + { + Assignments.Builder builder = Assignments.builder(); + variables.forEach(variable -> builder.put(variable, variable)); + return builder.build(); + } + public static boolean isIdentity(Assignments assignments, VariableReferenceExpression output) { //TODO this will be checking against VariableExpression once getOutput returns VariableReferenceExpression - Expression expression = assignments.get(output); + Expression expression = castToExpression(assignments.get(output)); return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); } @@ -64,4 +77,23 @@ public static Assignments identityAssignmentsAsSymbolReferences(VariableReferenc { return identityAssignmentsAsSymbolReferences(asList(variables)); } + + public static Assignments rewrite(Assignments assignments, Function rewrite) + { + return assignments.entrySet().stream() + .map(entry -> Maps.immutableEntry(entry.getKey(), castToRowExpression(rewrite.apply(castToExpression(entry.getValue()))))) + .collect(toAssignments()); + } + + private static Collector, Assignments.Builder, Assignments> toAssignments() + { + return Collector.of( + Assignments::builder, + (builder, entry) -> builder.put(entry.getKey(), entry.getValue()), + (left, right) -> { + left.putAll(right.build()); + return left; + }, + Assignments.Builder::build); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java index 799762cc0e990..e934416c41b31 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/Assignments.java @@ -13,17 +13,13 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.ExpressionRewriter; -import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; import java.util.Collection; import java.util.LinkedHashMap; @@ -32,7 +28,7 @@ import java.util.Map.Entry; import java.util.Set; import java.util.function.BiConsumer; -import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collector; import static com.google.common.base.Preconditions.checkState; @@ -47,7 +43,12 @@ public static Builder builder() return new Builder(); } - public static Assignments copyOf(Map assignments) + public static Builder builder(Map assignments) + { + return new Builder().putAll(assignments); + } + + public static Assignments copyOf(Map assignments) { return builder() .putAll(assignments) @@ -59,20 +60,20 @@ public static Assignments of() return builder().build(); } - public static Assignments of(VariableReferenceExpression variable, Expression expression) + public static Assignments of(VariableReferenceExpression variable, RowExpression expression) { return builder().put(variable, expression).build(); } - public static Assignments of(VariableReferenceExpression variable1, Expression expression1, VariableReferenceExpression variable2, Expression expression2) + public static Assignments of(VariableReferenceExpression variable1, RowExpression expression1, VariableReferenceExpression variable2, RowExpression expression2) { return builder().put(variable1, expression1).put(variable2, expression2).build(); } - private final Map assignments; + private final Map assignments; @JsonCreator - public Assignments(@JsonProperty("assignments") Map assignments) + public Assignments(@JsonProperty("assignments") Map assignments) { this.assignments = ImmutableMap.copyOf(requireNonNull(assignments, "assignments is null")); } @@ -83,23 +84,11 @@ public List getOutputs() } @JsonProperty("assignments") - public Map getMap() + public Map getMap() { return assignments; } - public Assignments rewrite(ExpressionRewriter rewriter) - { - return rewrite(expression -> ExpressionTreeRewriter.rewriteWith(rewriter, expression)); - } - - public Assignments rewrite(Function rewrite) - { - return assignments.entrySet().stream() - .map(entry -> Maps.immutableEntry(entry.getKey(), rewrite.apply(entry.getValue()))) - .collect(toAssignments()); - } - public Assignments filter(Collection variables) { return filter(variables::contains); @@ -108,11 +97,11 @@ public Assignments filter(Collection variables) public Assignments filter(Predicate predicate) { return assignments.entrySet().stream() - .filter(entry -> predicate.apply(entry.getKey())) + .filter(entry -> predicate.test(entry.getKey())) .collect(toAssignments()); } - private Collector, Builder, Assignments> toAssignments() + private Collector, Builder, Assignments> toAssignments() { return Collector.of( Assignments::builder, @@ -121,10 +110,10 @@ private Collector, Builder, Assig left.putAll(right.build()); return left; }, - Assignments.Builder::build); + Builder::build); } - public Collection getExpressions() + public Collection getExpressions() { return assignments.values(); } @@ -139,19 +128,19 @@ public Set getVariables() return assignments.keySet(); } - public Set> entrySet() + public Set> entrySet() { return assignments.entrySet(); } - public Expression get(VariableReferenceExpression variable) + public RowExpression get(VariableReferenceExpression variable) { return assignments.get(variable); } - public Expression get(Symbol symbol) + public RowExpression get(Symbol symbol) { - List candidate = assignments.entrySet().stream() + List candidate = assignments.entrySet().stream() .filter(entry -> entry.getKey().getName().equals(symbol.getName())) .map(Entry::getValue) .collect(toImmutableList()); @@ -171,7 +160,7 @@ public boolean isEmpty() return size() == 0; } - public void forEach(BiConsumer consumer) + public void forEach(BiConsumer consumer) { assignments.forEach(consumer); } @@ -199,25 +188,25 @@ public int hashCode() public static class Builder { - private final Map assignments = new LinkedHashMap<>(); + private final Map assignments = new LinkedHashMap<>(); public Builder putAll(Assignments assignments) { return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Entry assignment : assignments.entrySet()) { + for (Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; } - public Builder put(VariableReferenceExpression variable, Expression expression) + public Builder put(VariableReferenceExpression variable, RowExpression expression) { if (assignments.containsKey(variable)) { - Expression assignment = assignments.get(variable); + RowExpression assignment = assignments.get(variable); checkState( assignment.equals(expression), "Variable %s already has assignment %s, while adding %s", @@ -229,7 +218,7 @@ public Builder put(VariableReferenceExpression variable, Expression expression) return this; } - public Builder put(Entry assignment) + public Builder put(Entry assignment) { put(assignment.getKey(), assignment.getValue()); return this; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 435a190d56f3d..b62bcd4a99eb8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -1056,12 +1056,12 @@ private Void processChildren(PlanNode node, Void context) private void printAssignments(NodeRepresentation nodeOutput, Assignments assignments) { - for (Map.Entry entry : assignments.getMap().entrySet()) { - if (entry.getValue() instanceof SymbolReference && ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + for (Map.Entry entry : assignments.getMap().entrySet()) { + if (entry.getValue() instanceof VariableReferenceExpression && ((VariableReferenceExpression) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments continue; } - nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), entry.getValue()); + nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), formatter.formatRowExpression(entry.getValue())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index a356fa2a80da6..23d626c3b8fc2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -42,6 +43,7 @@ import java.util.Map; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.Preconditions.checkArgument; @@ -116,15 +118,22 @@ public Void visitProject(ProjectNode node, Void context) { visitPlan(node, context); - for (Map.Entry entry : node.getAssignments().entrySet()) { - if (entry.getValue() instanceof SymbolReference) { - SymbolReference symbolReference = (SymbolReference) entry.getValue(); - verifyTypeSignature(entry.getKey(), types.get(Symbol.from(symbolReference)).getTypeSignature()); - continue; + for (Map.Entry entry : node.getAssignments().entrySet()) { + RowExpression expression = entry.getValue(); + if (isExpression(expression)) { + if (castToExpression(expression) instanceof SymbolReference) { + SymbolReference symbolReference = (SymbolReference) castToExpression(expression); + verifyTypeSignature(entry.getKey(), types.get(Symbol.from(symbolReference)).getTypeSignature()); + continue; + } + Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, castToExpression(expression), emptyList(), warningCollector); + Type actualType = expressionTypes.get(NodeRef.of(castToExpression(expression))); + verifyTypeSignature(entry.getKey(), actualType.getTypeSignature()); + } + else { + Type actualType = expression.getType(); + verifyTypeSignature(entry.getKey(), actualType.getTypeSignature()); } - Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, entry.getValue(), emptyList(), warningCollector); - Type actualType = expressionTypes.get(NodeRef.of(entry.getValue())); - verifyTypeSignature(entry.getKey(), actualType.getTypeSignature()); } return null; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index eb2deda36ecfe..65e58e6a4ac73 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.sql.parser.SqlParser; @@ -64,7 +65,6 @@ import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.ValuesNode; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.Expression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -278,7 +278,7 @@ public Void visitProject(ProjectNode node, Set boun source.accept(this, boundVariables); // visit child Set inputs = createInputs(source, boundVariables); - for (Expression expression : node.getAssignments().getExpressions()) { + for (RowExpression expression : node.getAssignments().getExpressions()) { Set dependencies = SymbolsExtractor.extractUniqueVariable(expression, types); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } @@ -674,7 +674,7 @@ public Void visitApply(ApplyNode node, Set boundVar .addAll(createInputs(node.getInput(), boundVariables)) .build(); - for (Expression expression : node.getSubqueryAssignments().getExpressions()) { + for (RowExpression expression : node.getSubqueryAssignments().getExpressions()) { Set dependencies = SymbolsExtractor.extractUniqueVariable(expression, types); checkDependencies(inputs, dependencies, "Invalid node. Expression dependencies (%s) not in source plan output (%s)", dependencies, inputs); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/OriginalExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/OriginalExpressionUtils.java index 43ee7d0e0f8c4..c676c269efd50 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/OriginalExpressionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/OriginalExpressionUtils.java @@ -41,7 +41,7 @@ public static RowExpression castToRowExpression(Expression expression) return new OriginalExpression(expression); } - public static Expression asSymbolReference(VariableReferenceExpression variable) + public static SymbolReference asSymbolReference(VariableReferenceExpression variable) { return new SymbolReference(variable.getName()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java index 4d8854a2331e7..bcb65586bf35e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/ProjectNodeUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.Expression; @@ -20,17 +21,30 @@ import java.util.Map; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; + public class ProjectNodeUtils { private ProjectNodeUtils() {} public static boolean isIdentity(ProjectNode projectNode) { - for (Map.Entry entry : projectNode.getAssignments().entrySet()) { - Expression expression = entry.getValue(); + for (Map.Entry entry : projectNode.getAssignments().entrySet()) { + RowExpression value = entry.getValue(); VariableReferenceExpression variable = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(variable.getName()))) { - return false; + // It is used in CostCalculator so currently we need to handle both Expression and RowExpression + // TODO remove handling of Expression once all optimization rule uses RowExpression + if (isExpression(value)) { + Expression expression = castToExpression(value); + if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(variable.getName()))) { + return false; + } + } + else { + if (!(value instanceof VariableReferenceExpression && ((VariableReferenceExpression) value).getName().equals(variable.getName()))) { + return false; + } } } return true; diff --git a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 6a6eae1c1827f..3a3c771d262c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -391,13 +391,13 @@ public Void visitFilter(FilterNode node, Void context) public Void visitProject(ProjectNode node, Void context) { StringBuilder builder = new StringBuilder(); - for (Map.Entry entry : node.getAssignments().entrySet()) { - if ((entry.getValue() instanceof SymbolReference) && - ((SymbolReference) entry.getValue()).getName().equals(entry.getKey().getName())) { + for (Map.Entry entry : node.getAssignments().entrySet()) { + if ((entry.getValue() instanceof VariableReferenceExpression) && + ((VariableReferenceExpression) entry.getValue()).getName().equals(entry.getKey().getName())) { // skip identity assignments continue; } - builder.append(format("%s := %s\\n", entry.getKey(), entry.getValue())); + builder.append(format("%s := %s\\n", entry.getKey(), formatter.formatRowExpression(entry.getValue()))); } printNode(node, "Project", builder.toString(), NODE_COLORS.get(NodeType.PROJECT)); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 5b0cf18bcce7d..8c7ecddff8d08 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -47,7 +47,6 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.optimizations.TranslateExpressions; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -79,6 +78,7 @@ import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.count; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; @@ -805,7 +805,7 @@ private PlanNode project(String id, PlanNode source, VariableReferenceExpression return new ProjectNode( new PlanNodeId(id), source, - Assignments.of(variable, expression)); + assignment(variable, expression)); } private AggregationNode aggregation(String id, PlanNode source) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 9931386590b45..1a808ca94543b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -29,7 +29,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.LimitNode; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -76,6 +75,7 @@ import static com.facebook.presto.sql.ExpressionUtils.and; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.ExpressionUtils.or; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.count; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -236,7 +236,7 @@ public void testProject() equals(AE, BE), equals(BE, CE), lessThan(CE, bigintLiteral(10)))), - Assignments.of(DV, AE, EV, CE)); + assignment(DV, AE, EV, CE)); Expression effectivePredicate = effectivePredicateExtractor.extract(node, types); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index ced3a86ac8905..a290606eea74d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -852,7 +852,7 @@ public void testBroadcastCorrelatedSubqueryAvoidsRemoteExchangeBeforeAggregation // region is unpartitioned, AssignUniqueId should provide satisfying partitioning for count(*) after LEFT JOIN assertPlanWithSession( - "SELECT (SELECT count(*) FROM region r2 WHERE r2.regionkey > r1.regionkey) FROM region r1", + "SELECT (SELECT COUNT(*) FROM region r2 WHERE r2.regionkey > r1.regionkey) FROM region r1", broadcastJoin, false, joinBuildSideWithRemoteExchange, @@ -860,8 +860,8 @@ public void testBroadcastCorrelatedSubqueryAvoidsRemoteExchangeBeforeAggregation // orders is naturally partitioned, AssignUniqueId should not overwrite its natural partitioning assertPlanWithSession( - "SELECT count(count) " + - "FROM (SELECT o1.orderkey orderkey, (SELECT count(*) FROM orders o2 WHERE o2.orderkey > o1.orderkey) count FROM orders o1) " + + "SELECT COUNT(COUNT) " + + "FROM (SELECT o1.orderkey orderkey, (SELECT COUNT(*) FROM orders o2 WHERE o2.orderkey > o1.orderkey) COUNT FROM orders o1) " + "GROUP BY orderkey", broadcastJoin, false, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 86af8009ea5a7..dbcc2fd6065e7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -67,6 +67,7 @@ import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.variable; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; @Test(singleThreaded = true) public class TestTypeValidator @@ -133,8 +134,8 @@ public void testValidProject() Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); Expression expression2 = new Cast(columnC.toSymbolReference(), StandardTypes.BIGINT); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newVariable(expression1, BIGINT), expression1) - .put(symbolAllocator.newVariable(expression2, BIGINT), expression2) + .put(symbolAllocator.newVariable(expression1, BIGINT), castToRowExpression(expression1)) + .put(symbolAllocator.newVariable(expression2, BIGINT), castToRowExpression(expression2)) .build(); PlanNode node = new ProjectNode( newId(), @@ -224,8 +225,8 @@ public void testValidTypeOnlyCoercion() { Expression expression = new Cast(columnB.toSymbolReference(), StandardTypes.BIGINT); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newVariable(expression, BIGINT), expression) - .put(symbolAllocator.newVariable(columnE.toSymbolReference(), VARCHAR), columnE.toSymbolReference()) // implicit coercion from varchar(3) to varchar + .put(symbolAllocator.newVariable(expression, BIGINT), castToRowExpression(expression)) + .put(symbolAllocator.newVariable(columnE.toSymbolReference(), VARCHAR), castToRowExpression(columnE.toSymbolReference())) // implicit coercion from varchar(3) to varchar .build(); PlanNode node = new ProjectNode(newId(), baseTableScan, assignments); @@ -238,8 +239,8 @@ public void testInvalidProject() Expression expression1 = new Cast(columnB.toSymbolReference(), StandardTypes.INTEGER); Expression expression2 = new Cast(columnA.toSymbolReference(), StandardTypes.INTEGER); Assignments assignments = Assignments.builder() - .put(symbolAllocator.newVariable(expression1, BIGINT), expression1) // should be INTEGER - .put(symbolAllocator.newVariable(expression1, INTEGER), expression2) + .put(symbolAllocator.newVariable(expression1, BIGINT), castToRowExpression(expression1)) // should be INTEGER + .put(symbolAllocator.newVariable(expression1, INTEGER), castToRowExpression(expression2)) .build(); PlanNode node = new ProjectNode( newId(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java index e29c89b1921c1..a6f52a3a25850 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionMatcher.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.plan.ApplyNode; @@ -29,6 +30,8 @@ import java.util.stream.Collectors; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; @@ -54,29 +57,38 @@ private Expression expression(String sql) public Optional getAssignedVariable(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) { Optional result = Optional.empty(); - ImmutableList.Builder matchesBuilder = ImmutableList.builder(); - Map assignments = getAssignments(node); + ImmutableList.Builder matchesBuilder = ImmutableList.builder(); + Map assignments = getAssignments(node); if (assignments == null) { return result; } - ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); - - for (Map.Entry assignment : assignments.entrySet()) { - if (verifier.process(assignment.getValue(), expression)) { - result = Optional.of(assignment.getKey()); - matchesBuilder.add(assignment.getValue()); + for (Map.Entry assignment : assignments.entrySet()) { + RowExpression rightValue = assignment.getValue(); + if (isExpression(rightValue)) { + ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases); + if (verifier.process(castToExpression(rightValue), expression)) { + result = Optional.of(assignment.getKey()); + matchesBuilder.add(castToExpression(rightValue)); + } + } + else { + RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session); + if (verifier.process(expression, rightValue)) { + result = Optional.of(assignment.getKey()); + matchesBuilder.add(rightValue); + } } } - List matches = matchesBuilder.build(); + List matches = matchesBuilder.build(); checkState(matches.size() < 2, "Ambiguous expression %s matches multiple assignments", expression, - (matches.stream().map(Expression::toString).collect(Collectors.joining(", ")))); + (matches.stream().map(Object::toString).collect(Collectors.joining(", ")))); return result; } - private static Map getAssignments(PlanNode node) + private static Map getAssignments(PlanNode node) { if (node instanceof ProjectNode) { ProjectNode projectNode = (ProjectNode) node; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java index f14eb2de071bb..ac5610253ffda 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchingVisitor.java @@ -24,12 +24,13 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.ProjectNode; -import com.facebook.presto.sql.tree.SymbolReference; import java.util.List; import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -66,7 +67,7 @@ public MatchResult visitExchange(ExchangeNode node, PlanMatchPattern pattern) for (List inputs : allInputs) { Assignments.Builder assignments = Assignments.builder(); for (int i = 0; i < inputs.size(); ++i) { - assignments.put(outputs.get(i), new SymbolReference(inputs.get(i).getName())); + assignments.put(outputs.get(i), castToRowExpression(asSymbolReference(inputs.get(i)))); } newAliases = newAliases.updateAssignments(assignments.build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java index f3463ece9cca3..5542fb7293679 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/SymbolAliases.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.planner.assertions; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; @@ -24,6 +24,9 @@ import java.util.Map; import java.util.Optional; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static java.lang.String.format; @@ -97,9 +100,16 @@ private static String toKey(String alias) private Map getUpdatedAssignments(Assignments assignments) { ImmutableMap.Builder mapUpdate = ImmutableMap.builder(); - for (Map.Entry assignment : assignments.getMap().entrySet()) { + for (Map.Entry assignment : assignments.getMap().entrySet()) { for (Map.Entry existingAlias : map.entrySet()) { - if (assignment.getValue().equals(existingAlias.getValue())) { + RowExpression expression = assignment.getValue(); + if (isExpression(expression) && castToExpression(expression).equals(existingAlias.getValue())) { + // Simple symbol rename + mapUpdate.put(existingAlias.getKey(), asSymbolReference(assignment.getKey())); + } + else if (!isExpression(expression) && + (expression instanceof VariableReferenceExpression) && + ((VariableReferenceExpression) expression).getName().equals(existingAlias.getValue().getName())) { // Simple symbol rename mapUpdate.put(existingAlias.getKey(), new SymbolReference(assignment.getKey().getName())); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 962299622b729..ddb03e351d59a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -20,7 +20,6 @@ import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.ProjectNode; @@ -42,6 +41,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static com.google.common.base.Preconditions.checkArgument; @@ -250,7 +250,7 @@ private PlanNode projectNode(PlanNode source, VariableReferenceExpression variab return new ProjectNode( idAllocator.getNextId(), source, - Assignments.of(variable, expression)); + assignment(variable, expression)); } private VariableReferenceExpression variable(String name) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java index 08c0bc72463ac..2340fc8f1ba96 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestExpressionRewriteRuleSet.java @@ -17,7 +17,6 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; @@ -35,6 +34,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; public class TestExpressionRewriteRuleSet @@ -58,7 +58,7 @@ public void testProjectionExpressionRewrite() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.variable("y"), PlanBuilder.expression("x IS NOT NULL")), + assignment(p.variable("y"), PlanBuilder.expression("x IS NOT NULL")), p.values(p.variable("x")))) .matches( project(ImmutableMap.of("y", expression("0")), values("x"))); @@ -69,7 +69,7 @@ public void testProjectionExpressionNotRewritten() { tester().assertThat(zeroRewriter.projectExpressionRewrite()) .on(p -> p.project( - Assignments.of(p.variable("y"), PlanBuilder.expression("0")), + assignment(p.variable("y"), PlanBuilder.expression("0")), p.values(p.variable("x")))) .doesNotFire(); } @@ -162,7 +162,7 @@ public void testApplyExpressionRewrite() { tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( - Assignments.of( + assignment( p.variable("a", BIGINT), new InPredicate( new LongLiteral("1"), @@ -185,7 +185,7 @@ public void testApplyExpressionNotRewritten() { tester().assertThat(applyRewriter.applyExpressionRewrite()) .on(p -> p.apply( - Assignments.of( + assignment( p.variable("a", BIGINT), new InPredicate( new LongLiteral("0"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java index f31e91a1d18f8..3db7ef0441ffc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestInlineProjections.java @@ -22,6 +22,8 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.castToRowExpression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; @@ -35,19 +37,19 @@ public void test() .on(p -> p.project( Assignments.builder() - .put(p.variable("identity"), expression("symbol")) // identity - .put(p.variable("multi_complex_1"), expression("complex + 1")) // complex expression referenced multiple times - .put(p.variable("multi_complex_2"), expression("complex + 2")) // complex expression referenced multiple times - .put(p.variable("multi_literal_1"), expression("literal + 1")) // literal referenced multiple times - .put(p.variable("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times - .put(p.variable("single_complex"), expression("complex_2 + 2")) // complex expression reference only once - .put(p.variable("try"), expression("try(complex / literal)")) + .put(p.variable("identity"), castToRowExpression("symbol")) // identity + .put(p.variable("multi_complex_1"), castToRowExpression("complex + 1")) // complex expression referenced multiple times + .put(p.variable("multi_complex_2"), castToRowExpression("complex + 2")) // complex expression referenced multiple times + .put(p.variable("multi_literal_1"), castToRowExpression("literal + 1")) // literal referenced multiple times + .put(p.variable("multi_literal_2"), castToRowExpression("literal + 2")) // literal referenced multiple times + .put(p.variable("single_complex"), castToRowExpression("complex_2 + 2")) // complex expression reference only once + .put(p.variable("try"), castToRowExpression("try(complex / literal)")) .build(), p.project(Assignments.builder() - .put(p.variable("symbol"), expression("x")) - .put(p.variable("complex"), expression("x * 2")) - .put(p.variable("literal"), expression("1")) - .put(p.variable("complex_2"), expression("x - 1")) + .put(p.variable("symbol"), castToRowExpression("x")) + .put(p.variable("complex"), castToRowExpression("x * 2")) + .put(p.variable("literal"), castToRowExpression("1")) + .put(p.variable("complex_2"), castToRowExpression("x - 1")) .build(), p.values(p.variable("x"))))) .matches( @@ -74,7 +76,7 @@ public void testIdentityProjections() tester().assertThat(new InlineProjections()) .on(p -> p.project( - Assignments.of(p.variable("output"), expression("value")), + assignment(p.variable("output"), expression("value")), p.project( identityAssignmentsAsSymbolReferences(p.variable("value")), p.values(p.variable("value"))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index f208afb3696c8..f20313444f06c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -38,6 +38,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.castToRowExpression; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; @@ -182,7 +183,7 @@ public void testIntermediateProjectNodes() ImmutableMap.of(p.variable(p.symbol("lagOutput")), newWindowNodeFunction("lag", LAG_FUNCTION_HANDLE, "a", "one")), p.project( Assignments.builder() - .put(p.variable("one"), expression("CAST(1 AS bigint)")) + .put(p.variable("one"), castToRowExpression("CAST(1 AS bigint)")) .putAll(identitiesAsSymbolReferences(ImmutableList.of(p.variable("a"), p.variable("avgOutput")))) .build(), p.project( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java index 250cc596be80b..3b3292b9300fc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneCountAggregationOverScalar.java @@ -18,9 +18,7 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.plan.AggregationNode; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SymbolReference; @@ -36,6 +34,8 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; @@ -91,7 +91,7 @@ public void testFiresOnCountAggregateOverValues() .globalGrouping() .source(p.values( ImmutableList.of(p.variable(p.symbol("orderkey"))), - ImmutableList.of(PlanBuilder.constantExpressions(BIGINT, 1)))))) + ImmutableList.of(constantExpressions(BIGINT, 1)))))) .matches(values(ImmutableMap.of("count_1", 0))); } @@ -147,7 +147,7 @@ public void testDoesNotFireOnNestedNonCountAggregate() .globalGrouping() .source( p.project( - Assignments.of(totalPriceVariable, totalPrice.toSymbolReference()), + assignment(totalPriceVariable, totalPrice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java index 2ff7954f7f742..15df357def151 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMarkDistinctColumns.java @@ -15,8 +15,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -25,7 +23,9 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; public class TestPruneMarkDistinctColumns extends BaseRuleTest @@ -41,7 +41,7 @@ public void testMarkerSymbolNotReferenced() VariableReferenceExpression mark = p.variable("mark"); VariableReferenceExpression unused = p.variable("unused"); return p.project( - Assignments.of(key2, new SymbolReference(key.getName())), + assignment(key2, asSymbolReference(key)), p.markDistinct(mark, ImmutableList.of(key), p.values(key, unused))); }) .matches( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java index b8a7a74c10142..18f8cde1f47d9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableScanColumns.java @@ -19,7 +19,6 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.facebook.presto.testing.TestingTransactionHandle; import com.facebook.presto.tpch.TpchColumnHandle; @@ -34,6 +33,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; @@ -51,7 +51,7 @@ public void testNotAllOutputsReferenced() Symbol totalprice = p.symbol("totalprice", DOUBLE); VariableReferenceExpression totalpriceVariable = new VariableReferenceExpression(totalprice.getName(), DOUBLE); return p.project( - Assignments.of(p.variable("x"), totalprice.toSymbolReference()), + assignment(p.variable("x"), totalprice.toSymbolReference()), p.tableScan( new TableHandle( new ConnectorId("local"), @@ -77,7 +77,7 @@ public void testAllOutputsReferenced() Symbol x = p.symbol("x"); VariableReferenceExpression xv = p.variable(x); return p.project( - Assignments.of(p.variable("y"), expression("x")), + assignment(p.variable("y"), expression("x")), p.tableScan( ImmutableList.of(xv), ImmutableMap.of(p.variable(p.symbol("x")), new TestingColumnHandle("x")))); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java index 1732c8a3e084c..df6e6ead59f40 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneValuesColumns.java @@ -15,7 +15,6 @@ import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -23,6 +22,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; @@ -35,7 +35,7 @@ public void testNotAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.variable("y"), expression("x")), + assignment(p.variable("y"), expression("x")), p.values( ImmutableList.of(p.variable("unused"), p.variable("x")), ImmutableList.of( @@ -57,7 +57,7 @@ public void testAllOutputsReferenced() tester().assertThat(new PruneValuesColumns()) .on(p -> p.project( - Assignments.of(p.variable("y"), expression("x")), + assignment(p.variable("y"), expression("x")), p.values(p.variable("x")))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java index c913a8d21e683..67d606b4ddfe5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.java @@ -16,7 +16,6 @@ import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -38,7 +37,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; -import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; @@ -202,9 +201,7 @@ public void testDoesNotFireWhenNotDistinct() .source( p.join( JoinNode.Type.LEFT, - p.project(Assignments.builder() - .put(identityAsSymbolReference(p.variable("COL1", BIGINT))) - .build(), + p.project(identityAssignmentsAsSymbolReferences(p.variable("COL1", BIGINT)), p.aggregation(builder -> builder.singleGroupingSet(p.variable("COL1"), p.variable("unused")) .source( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java index 00176923766eb..19583c3fb4847 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -15,8 +15,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; -import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -24,6 +22,8 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; public class TestPushLimitThroughProject @@ -37,7 +37,7 @@ public void testPushdownLimitNonIdentityProjection() VariableReferenceExpression a = p.variable("a"); return p.limit(1, p.project( - Assignments.of(a, TRUE_LITERAL), + assignment(a, TRUE_LITERAL), p.values())); }) .matches( @@ -54,7 +54,7 @@ public void testDoesntPushdownLimitThroughIdentityProjection() VariableReferenceExpression a = p.variable("a"); return p.limit(1, p.project( - Assignments.of(a, new SymbolReference(a.getName())), + assignment(a, asSymbolReference(a)), p.values(a))); }).doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index a4bcc2e7745c5..d00acd6191a67 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -30,8 +30,11 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; @@ -44,7 +47,7 @@ public void testDoesNotFireNoExchange() tester().assertThat(new PushProjectionThroughExchange()) .on(p -> p.project( - Assignments.of(p.variable("x"), new LongLiteral("3")), + assignment(p.variable("x"), new LongLiteral("3")), p.values(p.variable("a")))) .doesNotFire(); } @@ -60,8 +63,8 @@ public void testDoesNotFireNarrowingProjection() return p.project( Assignments.builder() - .put(a, new SymbolReference(a.getName())) - .put(b, new SymbolReference(b.getName())) + .put(a, castToRowExpression(asSymbolReference(a))) + .put(b, castToRowExpression(asSymbolReference(b))) .build(), p.exchange(e -> e .addSource(p.values(a, b, c)) @@ -82,7 +85,7 @@ public void testSimpleMultipleInputs() VariableReferenceExpression c2 = p.variable("c2"); VariableReferenceExpression x = p.variable("x"); return p.project( - Assignments.of( + assignment( x, new LongLiteral("3"), c2, new SymbolReference("c")), p.exchange(e -> e @@ -120,9 +123,9 @@ public void testPartitioningColumnAndHashWithoutIdentityMappingInProjection() VariableReferenceExpression hTimes5 = p.variable("h_times_5"); return p.project( Assignments.builder() - .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) - .put(bTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5"))) - .put(hTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5"))) + .put(aTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5")))) + .put(bTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5")))) + .put(hTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5")))) .build(), p.exchange(e -> e .addSource( @@ -164,9 +167,9 @@ public void testOrderingColumnsArePreserved() OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(sortVariable), ImmutableMap.of(sortVariable, SortOrder.ASC_NULLS_FIRST)); return p.project( Assignments.builder() - .put(aTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5"))) - .put(bTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5"))) - .put(hTimes5, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5"))) + .put(aTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("a"), new LongLiteral("5")))) + .put(bTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("b"), new LongLiteral("5")))) + .put(hTimes5, castToRowExpression(new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference("h"), new LongLiteral("5")))) .build(), p.exchange(e -> e .addSource( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index e0677da3253a7..355032a178d27 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -15,7 +15,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.SymbolReference; @@ -28,6 +27,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.union; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; public class TestPushProjectionThroughUnion extends BaseRuleTest @@ -38,7 +38,7 @@ public void testDoesNotFire() tester().assertThat(new PushProjectionThroughUnion()) .on(p -> p.project( - Assignments.of(p.variable("x"), new LongLiteral("3")), + assignment(p.variable("x"), new LongLiteral("3")), p.values(p.variable("a")))) .doesNotFire(); } @@ -52,7 +52,9 @@ public void test() VariableReferenceExpression b = p.variable("b"); VariableReferenceExpression c = p.variable("c"); return p.project( - Assignments.of(p.variable("c_times_3"), new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference(c.getName()), new LongLiteral("3"))), + assignment( + p.variable("c_times_3"), + new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, new SymbolReference(c.getName()), new LongLiteral("3"))), p.union( ImmutableListMultimap.builder() .put(c, a) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java index c35e841c47e3d..7f8ce39c68c65 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveUnreferencedScalarApplyNodes.java @@ -20,6 +20,7 @@ import org.testng.annotations.Test; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; public class TestRemoveUnreferencedScalarApplyNodes extends BaseRuleTest @@ -29,7 +30,7 @@ public void testDoesNotFire() { tester().assertThat(new RemoveUnreferencedScalarApplyNodes()) .on(p -> p.apply( - Assignments.of(p.variable("z"), p.expression("x IN (y)")), + assignment(p.variable("z"), p.expression("x IN (y)")), ImmutableList.of(), p.values(p.variable("x")), p.values(p.variable("y")))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java index 009f843e9f9a7..03981e7a0e773 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarAggregationToJoin.java @@ -15,7 +15,6 @@ import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.planner.plan.JoinNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -29,6 +28,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; public class TestTransformCorrelatedScalarAggregationToJoin extends BaseRuleTest @@ -104,9 +104,9 @@ public void rewritesOnSubqueryWithProjection() { tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionManager())) .on(p -> p.lateral( - ImmutableList.of(p.variable(p.symbol("corr"))), + ImmutableList.of(p.variable("corr")), p.values(p.variable("corr")), - p.project(Assignments.of(p.variable("expr"), p.expression("sum + 1")), + p.project(assignment(p.variable("expr"), p.expression("sum + 1")), p.aggregation(ab -> ab .source(p.values(p.variable("a"), p.variable("b"))) .addAggregation(p.variable(p.symbol("sum")), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java index eac94f7ec1ea6..e0f6fe8b1e0e6 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedScalarSubquery.java @@ -19,7 +19,6 @@ import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -44,6 +43,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -123,7 +123,7 @@ public void rewritesOnSubqueryWithProjection() p.values(p.variable("corr")), p.enforceSingleRow( p.project( - Assignments.of(p.variable("a2"), p.expression("a * 2")), + assignment(p.variable("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.variable("a")), TWO_ROWS)))))) @@ -152,10 +152,10 @@ public void rewritesOnSubqueryWithProjectionOnTopEnforceSingleNode() ImmutableList.of(p.variable("corr")), p.values(p.variable("corr")), p.project( - Assignments.of(p.variable("a3"), p.expression("a2 + 1")), + assignment(p.variable("a3"), p.expression("a2 + 1")), p.enforceSingleRow( p.project( - Assignments.of(p.variable("a2"), p.expression("a * 2")), + assignment(p.variable("a2"), p.expression("a * 2")), p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.variable("a")), TWO_ROWS))))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java index 18de456e0d6ec..6a9c9ed4b7155 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformCorrelatedSingleRowSubqueryToProject.java @@ -17,7 +17,6 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; -import com.facebook.presto.sql.planner.plan.Assignments; import com.facebook.presto.testing.TestingTransactionHandle; import com.facebook.presto.tpch.TpchColumnHandle; import com.facebook.presto.tpch.TpchTableHandle; @@ -30,6 +29,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; @@ -60,7 +60,7 @@ public void testRewrite() ImmutableMap.of(p.variable(p.symbol("l_nationkey")), new TpchColumnHandle("nationkey", BIGINT))), p.project( - Assignments.of(p.variable("l_expr2"), expression("l_nationkey + 1")), + assignment(p.variable("l_expr2"), expression("l_nationkey + 1")), p.values( ImmutableList.of(), ImmutableList.of(ImmutableList.of()))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java index a08394a751fa2..7ade6679342ed 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformExistsApplyToLateralJoin.java @@ -29,6 +29,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; public class TestTransformExistsApplyToLateralJoin @@ -56,7 +57,7 @@ public void testRewrite() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + assignment(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), ImmutableList.of(), p.values(), p.values())) @@ -75,8 +76,8 @@ public void testRewritesToLimit() tester().assertThat(new TransformExistsApplyToLateralNode(tester().getMetadata().getFunctionManager())) .on(p -> p.apply( - Assignments.of(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), - ImmutableList.of(p.variable(p.symbol("corr"))), + assignment(p.variable("b", BOOLEAN), expression("EXISTS(SELECT TRUE)")), + ImmutableList.of(p.variable("corr")), p.values(p.variable("corr")), p.project(Assignments.of(), p.filter( diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java index f4c1a3663367d..73e277482bfcd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformUncorrelatedInPredicateSubqueryToSemiJoin.java @@ -24,6 +24,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static java.util.Collections.emptyList; public class TestTransformUncorrelatedInPredicateSubqueryToSemiJoin @@ -46,7 +47,7 @@ public void testDoesNotFireOnNonInPredicateSubquery() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of(p.variable("x"), new ExistsPredicate(new LongLiteral("1"))), + assignment(p.variable("x"), new ExistsPredicate(new LongLiteral("1"))), emptyList(), p.values(), p.values())) @@ -58,7 +59,7 @@ public void testFiresForInPredicate() { tester().assertThat(new TransformUncorrelatedInPredicateSubqueryToSemiJoin()) .on(p -> p.apply( - Assignments.of( + assignment( p.variable("x"), new InPredicate( new SymbolReference("y"), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 3a0b3afaf60d2..3df1f7014c3b0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -101,7 +101,6 @@ import static com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil.verifySubquerySupported; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.constantNull; -import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.facebook.presto.util.MoreLists.nElements; @@ -123,6 +122,16 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) this.metadata = metadata; } + public static Assignments assignment(VariableReferenceExpression variable, Expression expression) + { + return Assignments.builder().put(variable, OriginalExpressionUtils.castToRowExpression(expression)).build(); + } + + public static Assignments assignment(VariableReferenceExpression variable1, Expression expression1, VariableReferenceExpression variable2, Expression expression2) + { + return Assignments.builder().put(variable1, OriginalExpressionUtils.castToRowExpression(expression1)).put(variable2, OriginalExpressionUtils.castToRowExpression(expression2)).build(); + } + public OutputNode output(List columnNames, List variables, PlanNode source) { return new OutputNode( @@ -247,7 +256,7 @@ public MarkDistinctNode markDistinct(VariableReferenceExpression markerVariable, public FilterNode filter(Expression predicate, PlanNode source) { - return new FilterNode(idAllocator.getNextId(), source, castToRowExpression(predicate)); + return new FilterNode(idAllocator.getNextId(), source, OriginalExpressionUtils.castToRowExpression(predicate)); } public FilterNode filter(RowExpression predicate, PlanNode source) @@ -799,6 +808,11 @@ public static Expression expression(String sql) return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql)); } + public static RowExpression castToRowExpression(String sql) + { + return OriginalExpressionUtils.castToRowExpression(ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql))); + } + public static Expression expression(String sql, ParsingOptions options) { return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql, options)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java index 8d62faf152cee..d8d48679d5e6e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java @@ -24,7 +24,7 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; -import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static com.facebook.presto.sql.relational.Expressions.variable; public class TestRuleTester { @@ -35,7 +35,7 @@ public void testReportWrongMatch() tester.assertThat(new DummyReplaceNodeRule()) .on(p -> p.project( - Assignments.of(p.variable("y"), expression("x")), + Assignments.of(p.variable("y"), variable("x", BIGINT)), p.values( ImmutableList.of(p.variable(p.symbol("x"))), ImmutableList.of(constantExpressions(BIGINT, 1))))) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java index 25cf226cc8f32..7278b6560dd43 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestAssingments.java @@ -18,12 +18,13 @@ import org.testng.annotations.Test; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; import static org.testng.Assert.assertTrue; public class TestAssingments { - private final Assignments assignments = Assignments.of(new VariableReferenceExpression("test", BIGINT), TRUE_LITERAL); + private final Assignments assignments = assignment(new VariableReferenceExpression("test", BIGINT), TRUE_LITERAL); @Test public void testOutputsImmutable() From 09c5b4567403f562e0c3128f520418deefb82f8e Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:49 -0700 Subject: [PATCH 10/11] Remove OriginalExpression before plan assertion --- .../sql/planner/assertions/BasePlanTest.java | 24 +++++++++++++++++-- .../sql/planner/assertions/PlanAssert.java | 19 ++++++++++----- .../iterative/rule/test/RuleAssert.java | 14 +++++++++-- .../iterative/rule/test/TestRuleTester.java | 5 +++- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 3e49f49838497..4f6c7ba0c58fd 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs; +import com.facebook.presto.sql.planner.optimizations.TranslateExpressions; import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.tpch.TpchConnectorFactory; @@ -133,7 +134,14 @@ protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPatte protected void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern, List optimizers) { queryRunner.inTransaction(transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, optimizers, stage, WarningCollector.NOOP); + Plan actualPlan = queryRunner.createPlan( + transactionSession, + sql, + ImmutableList.builder() + .addAll(optimizers) + .add(translateExpressions()).build(), + stage, + WarningCollector.NOOP); PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), actualPlan, pattern); return null; }); @@ -158,11 +166,23 @@ protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMat new RuleStatsRecorder(), queryRunner.getStatsCalculator(), queryRunner.getCostCalculator(), - ImmutableSet.of(new RemoveRedundantIdentityProjections()))); + ImmutableSet.of(new RemoveRedundantIdentityProjections())), + translateExpressions()); assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED, pattern, optimizers); } + private PlanOptimizer translateExpressions() + { + return new IterativeOptimizer( + new RuleStatsRecorder(), + queryRunner.getStatsCalculator(), + queryRunner.getCostCalculator(), + new ImmutableSet.Builder() + .addAll(new TranslateExpressions(queryRunner.getMetadata(), queryRunner.getSqlParser()).rules()) + .build()); + } + protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean forceSingleNode, PlanMatchPattern pattern) { queryRunner.inTransaction(session, transactionSession -> { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index 021eba11750f0..445f8a7bb0518 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -23,6 +23,8 @@ import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.iterative.Lookup; +import java.util.function.Function; + import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.iterative.Plans.resolveGroupReferences; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textLogicalPlan; @@ -34,22 +36,27 @@ private PlanAssert() {} public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, PlanMatchPattern pattern) { - assertPlan(session, metadata, statsCalculator, actual, noLookup(), pattern); + assertPlan(session, metadata, statsCalculator, actual, pattern, Function.identity()); + } + + public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, PlanMatchPattern pattern, Function planSanitizer) + { + assertPlan(session, metadata, statsCalculator, actual, noLookup(), pattern, planSanitizer); } - public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern) + public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern, Function planSanitizer) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, actual.getTypes()); - assertPlan(session, metadata, statsProvider, actual, lookup, pattern); + assertPlan(session, metadata, statsProvider, actual, lookup, pattern, planSanitizer); } - public static void assertPlan(Session session, Metadata metadata, StatsProvider statsProvider, Plan actual, Lookup lookup, PlanMatchPattern pattern) + public static void assertPlan(Session session, Metadata metadata, StatsProvider statsProvider, Plan actual, Lookup lookup, PlanMatchPattern pattern, Function planSanitizer) { MatchResult matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata, statsProvider, lookup), pattern); if (!matches.isMatch()) { - String formattedPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata.getFunctionManager(), StatsAndCosts.empty(), session, 0); + String formattedPlan = textLogicalPlan(planSanitizer.apply(actual.getRoot()), actual.getTypes(), metadata.getFunctionManager(), StatsAndCosts.empty(), session, 0); PlanNode resolvedPlan = resolveGroupReferences(actual.getRoot(), lookup); - String resolvedFormattedPlan = textLogicalPlan(resolvedPlan, actual.getTypes(), metadata.getFunctionManager(), StatsAndCosts.empty(), session, 0); + String resolvedFormattedPlan = textLogicalPlan(planSanitizer.apply(resolvedPlan), actual.getTypes(), metadata.getFunctionManager(), StatsAndCosts.empty(), session, 0); throw new AssertionError(format( "Plan does not match, expected [\n\n%s\n] but found [\n\n%s\n] which resolves to [\n\n%s\n]", pattern, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 24729587c698a..e62448ddddcf2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -29,14 +29,18 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Memo; import com.facebook.presto.sql.planner.iterative.PlanNodeMatcher; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.TranslateExpressions; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -151,7 +155,7 @@ public void matches(PlanMatchPattern pattern) } inTransaction(session -> { - assertPlan(session, metadata, ruleApplication.statsProvider, new Plan(actual, types, StatsAndCosts.empty()), ruleApplication.lookup, pattern); + assertPlan(session, metadata, ruleApplication.statsProvider, new Plan(actual, types, StatsAndCosts.empty()), ruleApplication.lookup, pattern, planNode -> translateExpressions(planNode, types)); return null; }); } @@ -187,7 +191,7 @@ private String formatPlan(PlanNode plan, TypeProvider types) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, session, types); - return inTransaction(session -> textLogicalPlan(plan, types, metadata.getFunctionManager(), StatsAndCosts.create(plan, statsProvider, costProvider), session, 2, false)); + return inTransaction(session -> textLogicalPlan(translateExpressions(plan, types), types, metadata.getFunctionManager(), StatsAndCosts.create(plan, statsProvider, costProvider), session, 2, false)); } private T inTransaction(Function transactionSessionConsumer) @@ -201,6 +205,12 @@ private T inTransaction(Function transactionSessionConsumer) }); } + private PlanNode translateExpressions(PlanNode node, TypeProvider typeProvider) + { + IterativeOptimizer optimizer = new IterativeOptimizer(new RuleStatsRecorder(), statsCalculator, costCalculator, new TranslateExpressions(metadata, new SqlParser()).rules()); + return optimizer.optimize(node, session, typeProvider, new SymbolAllocator(typeProvider.allTypes()), idAllocator, WarningCollector.NOOP); + } + private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, SymbolAllocator symbolAllocator, Memo memo, Lookup lookup, Session session) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(memo), lookup, session, symbolAllocator.getTypes()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java index d8d48679d5e6e..317c26a9a8d33 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TestRuleTester.java @@ -25,6 +25,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; import static com.facebook.presto.sql.relational.Expressions.variable; +import static com.google.common.collect.ImmutableList.toImmutableList; public class TestRuleTester { @@ -56,7 +57,9 @@ public Pattern getPattern() @Override public Result apply(PlanNode node, Captures captures, Context context) { - return Result.ofPlanNode(node.replaceChildren(node.getSources())); + return Result.ofPlanNode(node.replaceChildren(node.getSources().stream() + .map(context.getLookup()::resolve) + .collect(toImmutableList()))); } } } From 2032ae48ef6f301d594d2601c863300c0f824a4e Mon Sep 17 00:00:00 2001 From: Yi He Date: Mon, 17 Jun 2019 17:24:51 -0700 Subject: [PATCH 11/11] Remove unused assertPlan --- .../presto/sql/planner/assertions/PlanAssert.java | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java index 445f8a7bb0518..39cf8eefb6e13 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanAssert.java @@ -35,19 +35,9 @@ public final class PlanAssert private PlanAssert() {} public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, PlanMatchPattern pattern) - { - assertPlan(session, metadata, statsCalculator, actual, pattern, Function.identity()); - } - - public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, PlanMatchPattern pattern, Function planSanitizer) - { - assertPlan(session, metadata, statsCalculator, actual, noLookup(), pattern, planSanitizer); - } - - public static void assertPlan(Session session, Metadata metadata, StatsCalculator statsCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern, Function planSanitizer) { StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, actual.getTypes()); - assertPlan(session, metadata, statsProvider, actual, lookup, pattern, planSanitizer); + assertPlan(session, metadata, statsProvider, actual, noLookup(), pattern, Function.identity()); } public static void assertPlan(Session session, Metadata metadata, StatsProvider statsProvider, Plan actual, Lookup lookup, PlanMatchPattern pattern, Function planSanitizer)