diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java index 15090721f42a..acbc6226ff2e 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java @@ -16,7 +16,6 @@ import io.trino.Session; import io.trino.matching.Pattern; import io.trino.sql.planner.TypeProvider; -import io.trino.sql.planner.iterative.GroupReference; import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.FilterNode; @@ -25,7 +24,6 @@ import java.util.Optional; -import static com.google.common.collect.MoreCollectors.onlyElement; import static io.trino.SystemSessionProperties.isNonEstimatablePredicateApproximationEnabled; import static io.trino.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static io.trino.sql.planner.plan.Patterns.filter; @@ -61,7 +59,7 @@ protected Optional doCalculate(FilterNode node, StatsProv if (!isNonEstimatablePredicateApproximationEnabled(session)) { return Optional.empty(); } - PlanNode nodeSource = resolveGroup(lookup, node.getSource()); + PlanNode nodeSource = lookup.resolve(node.getSource()); AggregationNode aggregationNode; // TODO match the required source nodes through separate patterns when // ComposableStatsCalculator allows patterns other than TypeOfPattern @@ -70,7 +68,7 @@ protected Optional doCalculate(FilterNode node, StatsProv if (!projectNode.isIdentity()) { return Optional.empty(); } - PlanNode projectNodeSource = resolveGroup(lookup, projectNode.getSource()); + PlanNode projectNodeSource = lookup.resolve(projectNode.getSource()); if (!(projectNodeSource instanceof AggregationNode)) { return Optional.empty(); } @@ -99,12 +97,4 @@ private Optional calculate(FilterNode filterNode, Aggrega } return Optional.of(filteredStats); } - - private static PlanNode resolveGroup(Lookup lookup, PlanNode node) - { - if (node instanceof GroupReference) { - return lookup.resolveGroup(node).collect(onlyElement()); - } - return node; - } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java index d44fa29972ad..029fecbb74f9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -27,7 +27,6 @@ import io.trino.matching.Pattern; import io.trino.sql.planner.OptimizerConfig.JoinDistributionType; import io.trino.sql.planner.TypeProvider; -import io.trino.sql.planner.iterative.GroupReference; import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.PlanNodeSearcher; @@ -145,12 +144,7 @@ private static double getFirstKnownOutputSizeInBytes(PlanNode node, Context cont static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider, TypeProvider typeProvider) { return Stream.of(node) - .flatMap(planNode -> { - if (planNode instanceof GroupReference) { - return lookup.resolveGroup(node); - } - return Stream.of(planNode); - }) + .map(lookup::resolve) .mapToDouble(resolvedNode -> { double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes( resolvedNode.getOutputSymbols(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java index ba8e1306ec83..53c4c4a3aa00 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java @@ -26,9 +26,8 @@ import io.trino.sql.planner.plan.UnionNode; import java.util.List; -import java.util.stream.Collectors; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; import static io.trino.sql.planner.plan.Patterns.aggregation; @@ -51,9 +50,9 @@ public Result apply(AggregationNode node, Captures captures, Context context) DistinctAggregationRewriter rewriter = new DistinctAggregationRewriter(lookup); List newSources = node.getSources().stream() - .flatMap(lookup::resolveGroup) + .map(lookup::resolve) .map(source -> source.accept(rewriter, true)) - .collect(Collectors.toList()); + .collect(toImmutableList()); if (rewriter.isRewritten()) { return Result.ofPlanNode(replaceChildren(node, newSources)); @@ -86,8 +85,9 @@ public boolean isRewritten() private PlanNode rewriteChildren(PlanNode node, Boolean context) { List newSources = node.getSources().stream() - .flatMap(lookup::resolveGroup) - .map(source -> source.accept(this, context)).collect(Collectors.toList()); + .map(lookup::resolve) + .map(source -> source.accept(this, context)) + .collect(toImmutableList()); return replaceChildren(node, newSources); } @@ -128,8 +128,7 @@ public PlanNode visitAggregation(AggregationNode node, Boolean context) { boolean distinct = isDistinctOperator(node); - PlanNode rewrittenNode = getOnlyElement(lookup.resolveGroup(node.getSource()) - .map(source -> source.accept(this, distinct)).collect(Collectors.toList())); + PlanNode rewrittenNode = lookup.resolve(node.getSource()).accept(this, distinct); if (context && distinct) { this.rewritten = true; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java index 6a4e8177b420..262cf4879d54 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationMerge.java @@ -29,9 +29,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; class SetOperationMerge { @@ -56,8 +56,8 @@ public Optional mergeFirstSource() { Lookup lookup = context.getLookup(); List sources = node.getSources().stream() - .flatMap(lookup::resolveGroup) - .collect(Collectors.toList()); + .map(lookup::resolve) + .collect(toImmutableList()); PlanNode child = sources.get(0); @@ -101,8 +101,8 @@ public Optional merge() Lookup lookup = context.getLookup(); List sources = node.getSources().stream() - .flatMap(lookup::resolveGroup) - .collect(Collectors.toList()); + .map(lookup::resolve) + .collect(toImmutableList()); ImmutableListMultimap.Builder newMappingsBuilder = ImmutableListMultimap.builder(); boolean resultIsDistinct = false;