diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index ef11d6fcf67a..595236e44afa 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -50,7 +50,6 @@ import io.trino.sql.analyzer.Scope; import io.trino.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; import io.trino.sql.planner.optimizations.PlanOptimizer; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.DeleteNode; import io.trino.sql.planner.plan.ExplainAnalyzeNode; @@ -134,6 +133,7 @@ import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.QueryPlanner.visibleFields; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.TableWriterNode.CreateReference; import static io.trino.sql.planner.plan.TableWriterNode.InsertReference; @@ -362,15 +362,11 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme PlanNode planNode = new StatisticsWriterNode( idAllocator.getNextId(), - new AggregationNode( + singleAggregation( idAllocator.getNextId(), TableScanNode.newInstance(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.buildOrThrow(), false, Optional.empty()), statisticAggregations.getAggregations(), - singleGroupingSet(groupingSymbols), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()), + singleGroupingSet(groupingSymbols)), new StatisticsWriterNode.WriteStatisticsReference(targetTable), symbolAllocator.newSymbol("rows", BIGINT), tableStatisticsMetadata.getTableStatistics().contains(ROW_COUNT), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java index 01fb88aef0e5..8ea0a615799f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java @@ -79,7 +79,10 @@ protected PlanNode visitPlan(PlanNode node, RewriteContext context) @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { - return new AggregationNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAggregations(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol()); + return AggregationNode.builderFrom(node) + .setId(idAllocator.getNextId()) + .setSource(context.rewrite(node.getSource())) + .build(); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 494cc69adc01..a207904874e6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -128,6 +128,7 @@ import static io.trino.sql.planner.PlanBuilder.newPlanBuilder; import static io.trino.sql.planner.ScopeAware.scopeAwareKey; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -333,15 +334,11 @@ public RelationPlan planExpand(Query query) PlanNode result = new UnionNode(idAllocator.getNextId(), nodesToUnion, unionSymbolMapping.build(), unionOutputSymbols); if (union.isDistinct()) { - result = new AggregationNode( + result = singleAggregation( idAllocator.getNextId(), result, ImmutableMap.of(), - singleGroupingSet(result.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(result.getOutputSymbols())); } return new RelationPlan(result, anchorPlan.getScope(), unionOutputSymbols, outerContext); @@ -1654,15 +1651,11 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node, List< .collect(Collectors.toList()); return subPlan.withNewRoot( - new AggregationNode( + singleAggregation( idAllocator.getNextId(), subPlan.getRoot(), ImmutableMap.of(), - singleGroupingSet(symbols), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty())); + singleGroupingSet(symbols))); } return subPlan; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 29dec22e387a..4ff8db098a76 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -33,7 +33,6 @@ import io.trino.sql.analyzer.Field; import io.trino.sql.analyzer.RelationType; import io.trino.sql.analyzer.Scope; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.CorrelatedJoinNode; import io.trino.sql.planner.plan.ExceptNode; @@ -122,6 +121,7 @@ import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions; import static io.trino.sql.planner.QueryPlanner.planWindowSpecification; import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.tree.Join.Type.CROSS; @@ -1160,14 +1160,10 @@ private SetOperationPlan process(SetOperation node) private PlanNode distinct(PlanNode node) { - return new AggregationNode(idAllocator.getNextId(), + return singleAggregation(idAllocator.getNextId(), node, ImmutableMap.of(), - singleGroupingSet(node.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(node.getOutputSymbols())); } private static final class SetOperationPlan diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java index dde4c8b4b90d..6b787648dc01 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -153,15 +153,12 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI { verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation"); ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation); - return new AggregationNode( - idAllocator.getNextId(), - gatheringExchange, - outputsAsInputs(aggregation.getAggregations()), - aggregation.getGroupingSets(), - aggregation.getPreGroupedSymbols(), - AggregationNode.Step.INTERMEDIATE, - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + return AggregationNode.builderFrom(aggregation) + .setId(idAllocator.getNextId()) + .setSource(gatheringExchange) + .setAggregations(outputsAsInputs(aggregation.getAggregations())) + .setStep(AggregationNode.Step.INTERMEDIATE) + .build(); } /** diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java index f33e6918ff26..d0ff406f812c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AggregationDecorrelation.java @@ -13,15 +13,19 @@ */ package io.trino.sql.planner.iterative.rule; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.PlanNode; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; + class AggregationDecorrelation { private AggregationDecorrelation() {} @@ -51,4 +55,26 @@ public static Map rewriteWithMasks(Map return rewritten.buildOrThrow(); } + + /** + * Creates distinct aggregation node based on existing distinct aggregation node. + * + * @see #isDistinctOperator(PlanNode) + */ + public static AggregationNode restoreDistinctAggregation( + AggregationNode distinct, + PlanNode source, + List groupingKeys) + { + checkArgument(isDistinctOperator(distinct)); + return new AggregationNode( + distinct.getId(), + source, + ImmutableMap.of(), + AggregationNode.singleGroupingSet(groupingKeys), + ImmutableList.of(), + distinct.getStep(), + Optional.empty(), + Optional.empty()); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java index d0807c088d51..b21137418df5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.java @@ -53,6 +53,7 @@ import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; @@ -337,15 +338,11 @@ private static AggregationNode withGroupingAndMask(AggregationNode aggregationNo .build()); } - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, rewriteWithMasks(aggregationNode.getAggregations(), masks.buildOrThrow()), - singleGroupingSet(groupingSymbols), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(groupingSymbols)); } private static AggregationNode withGrouping(AggregationNode aggregationNode, List groupingSymbols, PlanNode source) @@ -354,14 +351,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis .distinct() .collect(toImmutableList())); - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, aggregationNode.getAggregations(), - groupingSet, - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + groupingSet); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java index a4705a673911..fd95414d8ae1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateLeftUnnestWithGlobalAggregation.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import io.trino.matching.Captures; @@ -42,6 +41,7 @@ import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation; @@ -264,14 +264,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis .distinct() .collect(toImmutableList())); - return new AggregationNode( + return singleAggregation( aggregationNode.getId(), source, aggregationNode.getAggregations(), - groupingSet, - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + groupingSet); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index 8684ba663e5f..ee20ff805798 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -209,15 +209,9 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context } } if (anyRewritten) { - return Result.ofPlanNode(new AggregationNode( - aggregationNode.getId(), - aggregationNode.getSource(), - aggregations.buildOrThrow(), - aggregationNode.getGroupingSets(), - aggregationNode.getPreGroupedSymbols(), - aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol())); + return Result.ofPlanNode(AggregationNode.builderFrom(aggregationNode) + .setAggregations(aggregations.buildOrThrow()) + .build()); } return Result.empty(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java index f2e34133392b..638e8331f4e1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -140,20 +140,17 @@ else if (mask.isPresent()) { newAssignments.putIdentities(aggregationNode.getSource().getOutputSymbols()); return Result.ofPlanNode( - new AggregationNode( - context.getIdAllocator().getNextId(), - new FilterNode( + AggregationNode.builderFrom(aggregationNode) + .setId(context.getIdAllocator().getNextId()) + .setSource(new FilterNode( context.getIdAllocator().getNextId(), new ProjectNode( context.getIdAllocator().getNextId(), aggregationNode.getSource(), newAssignments.build()), - predicate), - aggregations.buildOrThrow(), - aggregationNode.getGroupingSets(), - ImmutableList.of(), - aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol())); + predicate)) + .setAggregations(aggregations.buildOrThrow()) + .setPreGroupedSymbols(ImmutableList.of()) + .build()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java index dc5a16b7a438..d132adf1741c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java @@ -159,14 +159,10 @@ public Result apply(AggregationNode parent, Captures captures, Context context) } return Result.ofPlanNode( - new AggregationNode( - parent.getId(), - subPlan, - newAggregations, - parent.getGroupingSets(), - ImmutableList.of(), - parent.getStep(), - parent.getHashSymbol(), - parent.getGroupIdSymbol())); + AggregationNode.builderFrom(parent) + .setSource(subPlan) + .setAggregations(newAggregations) + .setPreGroupedSymbols(ImmutableList.of()) + .build()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneAggregationColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneAggregationColumns.java index 7e63d714cbf4..4545882f5131 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneAggregationColumns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneAggregationColumns.java @@ -48,14 +48,8 @@ protected Optional pushDownProjectOff( // PruneAggregationSourceColumns will subsequently project off any newly unused inputs. return Optional.of( - new AggregationNode( - aggregationNode.getId(), - aggregationNode.getSource(), - prunedAggregations, - aggregationNode.getGroupingSets(), - aggregationNode.getPreGroupedSymbols(), - aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol())); + AggregationNode.builderFrom(aggregationNode) + .setAggregations(prunedAggregations) + .build()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java index 8a6452274190..7f1f2eb545dc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneDistinctAggregation.java @@ -139,15 +139,10 @@ public PlanNode visitAggregation(AggregationNode node, Boolean context) return rewrittenNode; } - return new AggregationNode( - node.getId(), - rewrittenNode, - node.getAggregations(), - node.getGroupingSets(), - ImmutableList.of(), - node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol()); + return AggregationNode.builderFrom(node) + .setSource(rewrittenNode) + .setPreGroupedSymbols(ImmutableList.of()) + .build(); } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInAggregation.java index d5e6bb94bff5..57d1c70af649 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInAggregation.java @@ -78,14 +78,8 @@ else if (metadata.getAggregationFunctionMetadata(context.getSession(), aggregati if (!anyRewritten) { return Result.empty(); } - return Result.ofPlanNode(new AggregationNode( - node.getId(), - node.getSource(), - aggregations.buildOrThrow(), - node.getGroupingSets(), - node.getPreGroupedSymbols(), - node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol())); + return Result.ofPlanNode(AggregationNode.builderFrom(node) + .setAggregations(aggregations.buildOrThrow()) + .build()); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 502aeea97d04..af67662b549f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -55,6 +55,7 @@ import static io.trino.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct; import static io.trino.sql.planner.optimizations.SymbolMapper.symbolMapper; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.aggregation; import static io.trino.sql.planner.plan.Patterns.join; @@ -137,15 +138,11 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont List groupingKeys = join.getCriteria().stream() .map(join.getType() == JoinNode.Type.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight) .collect(toImmutableList()); - AggregationNode rewrittenAggregation = new AggregationNode( - aggregation.getId(), - getInnerTable(join), - aggregation.getAggregations(), - singleGroupingSet(groupingKeys), - ImmutableList.of(), - aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + AggregationNode rewrittenAggregation = AggregationNode.builderFrom(aggregation) + .setSource(getInnerTable(join)) + .setGroupingSets(singleGroupingSet(groupingKeys)) + .setPreGroupedSymbols(ImmutableList.of()) + .build(); JoinNode rewrittenJoin; if (join.getType() == JoinNode.Type.LEFT) { @@ -309,15 +306,11 @@ private MappedAggregationInfo createAggregationOverNull(AggregationNode referenc Map aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.buildOrThrow(); // create an aggregation node whose source is the null row. - AggregationNode aggregationOverNullRow = new AggregationNode( + AggregationNode aggregationOverNullRow = singleAggregation( idAllocator.getNextId(), nullRow, aggregationsOverNullBuilder.buildOrThrow(), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + globalAggregation()); return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java index e9f62749b2f1..97612f283c78 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushFilterThroughCountAggregation.java @@ -212,15 +212,10 @@ private static Result pushFilter(FilterNode filterNode, AggregationNode aggregat aggregation.getOrderingScheme(), Optional.empty()); - AggregationNode newAggregationNode = new AggregationNode( - aggregationNode.getId(), - source, - ImmutableMap.of(countSymbol, newAggregation), - aggregationNode.getGroupingSets(), - aggregationNode.getPreGroupedSymbols(), - aggregationNode.getStep(), - aggregationNode.getHashSymbol(), - aggregationNode.getGroupIdSymbol()); + AggregationNode newAggregationNode = AggregationNode.builderFrom(aggregationNode) + .setSource(source) + .setAggregations(ImmutableMap.of(countSymbol, newAggregation)) + .build(); // Restore identity projection if it is present in the original plan. PlanNode filterSource = projectNode.map(project -> project.replaceChildren(ImmutableList.of(newAggregationNode))).orElse(newAggregationNode); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java index 1cf62c6ec881..4e0d9bcb5020 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java @@ -127,11 +127,11 @@ private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, C private Set getJoinRequiredSymbols(JoinNode node) { return Streams.concat( - node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft), - node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight), - node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(), - node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), - node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft), + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight), + node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(), + node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), + node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) .collect(toImmutableSet()); } @@ -158,15 +158,11 @@ private AggregationNode replaceAggregationSource( PlanNode source, List groupingKeys) { - return new AggregationNode( - aggregation.getId(), - source, - aggregation.getAggregations(), - singleGroupingSet(groupingKeys), - ImmutableList.of(), - aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol()); + return AggregationNode.builderFrom(aggregation) + .setSource(source) + .setGroupingSets(singleGroupingSet(groupingKeys)) + .setPreGroupedSymbols(ImmutableList.of()) + .build(); } private PlanNode pushPartialToJoin( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java index ea25275a0506..8a263bbc75ec 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveEmptyExceptBranches.java @@ -21,8 +21,6 @@ import io.trino.matching.Pattern; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; -import io.trino.sql.planner.plan.AggregationNode; -import io.trino.sql.planner.plan.AggregationNode.Step; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.ExceptNode; import io.trino.sql.planner.plan.PlanNode; @@ -30,9 +28,9 @@ import io.trino.sql.planner.plan.ValuesNode; import java.util.List; -import java.util.Optional; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isEmpty; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.except; @@ -96,18 +94,14 @@ public Result apply(ExceptNode node, Captures captures, Context context) if (node.isDistinct()) { return Result.ofPlanNode( - new AggregationNode( + singleAggregation( node.getId(), new ProjectNode( context.getIdAllocator().getNextId(), newSources.get(0), assignments.build()), ImmutableMap.of(), - singleGroupingSet(node.getOutputSymbols()), - ImmutableList.of(), - Step.SINGLE, - Optional.empty(), - Optional.empty())); + singleGroupingSet(node.getOutputSymbols()))); } return Result.ofPlanNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index ff9c7a6582ef..0e64c07fdfd5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -133,22 +133,17 @@ public Result apply(AggregationNode node, Captures captures, Context context) partitionCount = getHashPartitionCount(context.getSession()); } return Result.ofPlanNode( - new AggregationNode( - node.getId(), - new ProjectNode( + AggregationNode.builderFrom(node) + .setSource(new ProjectNode( context.getIdAllocator().getNextId(), node.getSource(), Assignments.builder() .putIdentities(node.getSource().getOutputSymbols()) .put(partitionCountSymbol, new LongLiteral(Integer.toString(partitionCount))) .putAll(envelopeAssignments.buildOrThrow()) - .build()), - aggregations.buildOrThrow(), - node.getGroupingSets(), - node.getPreGroupedSymbols(), - node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol())); + .build())) + .setAggregations(aggregations.buildOrThrow()) + .build()); } private boolean isStEnvelopeFunctionCall(Expression expression, ResolvedFunction stEnvelopeFunction) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java index 64b9cef1fdfb..0b37ba9d9e0c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java @@ -49,6 +49,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; @@ -180,14 +181,10 @@ private AggregationNode computeCounts(UnionNode sourceNode, List origina Optional.empty())); } - return new AggregationNode(idAllocator.getNextId(), + return singleAggregation(idAllocator.getNextId(), sourceNode, aggregations.buildOrThrow(), - singleGroupingSet(originalColumns), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(originalColumns)); } private WindowNode appendCounts(UnionNode sourceNode, List originalColumns, List markers, List countOutputs, Symbol rowNumberSymbol) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java index d8a25f30c0b7..63bba215993a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -98,15 +98,11 @@ public Result apply(AggregationNode parent, Captures captures, Context context) return Result.empty(); } - return Result.ofPlanNode(new AggregationNode( - parent.getId(), - child, - aggregations, - parent.getGroupingSets(), - ImmutableList.of(), - parent.getStep(), - parent.getHashSymbol(), - parent.getGroupIdSymbol())); + return Result.ofPlanNode(AggregationNode.builderFrom(parent) + .setSource(child) + .setAggregations(aggregations) + .setPreGroupedSymbols(ImmutableList.of()) + .build()); } private boolean isCountOverConstant(Session session, AggregationNode.Aggregation aggregation, Assignments inputs) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index 9475cf135700..f8d157873d5f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -27,16 +27,14 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.aggregation; -import static java.util.Collections.emptyList; /** * Implements distinct aggregations with similar inputs by transforming plans of the following shape: @@ -123,31 +121,25 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont .collect(Collectors.toSet()); return Result.ofPlanNode( - new AggregationNode( - aggregation.getId(), - new AggregationNode( - context.getIdAllocator().getNextId(), - aggregation.getSource(), - ImmutableMap.of(), - singleGroupingSet(ImmutableList.builder() - .addAll(aggregation.getGroupingKeys()) - .addAll(symbols) - .build()), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()), - // remove DISTINCT flag from function calls - aggregation.getAggregations() - .entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - e -> removeDistinct(e.getValue()))), - aggregation.getGroupingSets(), - emptyList(), - aggregation.getStep(), - aggregation.getHashSymbol(), - aggregation.getGroupIdSymbol())); + AggregationNode.builderFrom(aggregation) + .setSource( + singleAggregation( + context.getIdAllocator().getNextId(), + aggregation.getSource(), + ImmutableMap.of(), + singleGroupingSet(ImmutableList.builder() + .addAll(aggregation.getGroupingKeys()) + .addAll(symbols) + .build()))) + .setAggregations( + // remove DISTINCT flag from function calls + aggregation.getAggregations() + .entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> removeDistinct(e.getValue())))) + .setPreGroupedSymbols(ImmutableList.of()) + .build()); } private static Aggregation removeDistinct(Aggregation aggregation) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java index 24fc9970c7ff..b23e8404258b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedDistinctAggregationWithoutProjection.java @@ -130,18 +130,19 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co // restore aggregation AggregationNode aggregation = captures.get(AGGREGATION); - aggregation = new AggregationNode( - aggregation.getId(), - join, - aggregation.getAggregations(), - singleGroupingSet(ImmutableList.builder() - .addAll(join.getLeftOutputSymbols()) - .addAll(aggregation.getGroupingKeys()) - .build()), - ImmutableList.of(), - aggregation.getStep(), - Optional.empty(), - Optional.empty()); + aggregation = AggregationNode.builderFrom(aggregation) + .setSource(join) + .setGroupingSets( + singleGroupingSet(ImmutableList.builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(aggregation.getGroupingKeys()) + .build())) + .setPreGroupedSymbols( + ImmutableList.of()) + .setHashSymbol( + Optional.empty()) + .setGroupIdSymbol(Optional.empty()) + .build(); // restrict outputs Optional project = restrictOutputs(context.getIdAllocator(), aggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java index 297d1f92ca87..e6b3acce8bc9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java @@ -48,6 +48,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; @@ -204,19 +205,14 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co // restore distinct aggregation if (distinct != null) { - root = new AggregationNode( - distinct.getId(), + root = restoreDistinctAggregation( + distinct, join, - distinct.getAggregations(), - singleGroupingSet(ImmutableList.builder() + ImmutableList.builder() .addAll(join.getLeftOutputSymbols()) .add(nonNull) .addAll(distinct.getGroupingKeys()) - .build()), - ImmutableList.of(), - distinct.getStep(), - Optional.empty(), - Optional.empty()); + .build()); } // prepare mask symbols for aggregations diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java index 09295cd48303..4cb37fdb4685 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java @@ -45,6 +45,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -197,19 +198,14 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co // restore distinct aggregation if (distinct != null) { - root = new AggregationNode( - distinct.getId(), + root = restoreDistinctAggregation( + distinct, join, - distinct.getAggregations(), - singleGroupingSet(ImmutableList.builder() + ImmutableList.builder() .addAll(join.getLeftOutputSymbols()) .add(nonNull) .addAll(distinct.getGroupingKeys()) - .build()), - ImmutableList.of(), - distinct.getStep(), - Optional.empty(), - Optional.empty()); + .build()); } // prepare mask symbols for aggregations diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java index 8ad57dffdeb3..59ead81e287d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithProjection.java @@ -41,6 +41,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.Patterns.Aggregation.groupingColumns; @@ -182,34 +183,27 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co // restore distinct aggregation if (distinct != null) { - distinct = new AggregationNode( - distinct.getId(), + distinct = restoreDistinctAggregation( + distinct, join, - distinct.getAggregations(), - singleGroupingSet(ImmutableList.builder() + ImmutableList.builder() .addAll(join.getLeftOutputSymbols()) .addAll(distinct.getGroupingKeys()) - .build()), - ImmutableList.of(), - distinct.getStep(), - Optional.empty(), - Optional.empty()); + .build()); } // restore grouped aggregation AggregationNode groupedAggregation = captures.get(AGGREGATION); - groupedAggregation = new AggregationNode( - groupedAggregation.getId(), - distinct != null ? distinct : join, - groupedAggregation.getAggregations(), - singleGroupingSet(ImmutableList.builder() + groupedAggregation = AggregationNode.builderFrom(groupedAggregation) + .setSource(distinct != null ? distinct : join) + .setGroupingSets(singleGroupingSet(ImmutableList.builder() .addAll(join.getLeftOutputSymbols()) .addAll(groupedAggregation.getGroupingKeys()) - .build()), - ImmutableList.of(), - groupedAggregation.getStep(), - Optional.empty(), - Optional.empty()); + .build())) + .setPreGroupedSymbols(ImmutableList.of()) + .setHashSymbol(Optional.empty()) + .setGroupIdSymbol(Optional.empty()) + .build(); // restrict outputs and apply projection Set outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java index 4ebfacd4f882..d87919629260 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGroupedAggregationWithoutProjection.java @@ -36,6 +36,7 @@ import static io.trino.matching.Pattern.nonEmpty; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.isDistinctOperator; +import static io.trino.sql.planner.iterative.rule.AggregationDecorrelation.restoreDistinctAggregation; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; @@ -173,34 +174,29 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co // restore distinct aggregation if (distinct != null) { - distinct = new AggregationNode( - distinct.getId(), + distinct = restoreDistinctAggregation( + distinct, join, - distinct.getAggregations(), - singleGroupingSet(ImmutableList.builder() + ImmutableList.builder() .addAll(join.getLeftOutputSymbols()) .addAll(distinct.getGroupingKeys()) - .build()), - ImmutableList.of(), - distinct.getStep(), - Optional.empty(), - Optional.empty()); + .build()); } // restore grouped aggregation AggregationNode groupedAggregation = captures.get(AGGREGATION); - groupedAggregation = new AggregationNode( - groupedAggregation.getId(), - distinct != null ? distinct : join, - groupedAggregation.getAggregations(), - singleGroupingSet(ImmutableList.builder() - .addAll(join.getLeftOutputSymbols()) - .addAll(groupedAggregation.getGroupingKeys()) - .build()), - ImmutableList.of(), - groupedAggregation.getStep(), - Optional.empty(), - Optional.empty()); + groupedAggregation = AggregationNode.builderFrom(groupedAggregation) + .setSource(distinct != null ? distinct : join) + .setAggregations(groupedAggregation.getAggregations()) + .setGroupingSets( + singleGroupingSet(ImmutableList.builder() + .addAll(join.getLeftOutputSymbols()) + .addAll(groupedAggregation.getGroupingKeys()) + .build())) + .setPreGroupedSymbols(ImmutableList.of()) + .setHashSymbol(Optional.empty()) + .setGroupIdSymbol(Optional.empty()) + .build(); // restrict outputs Optional project = restrictOutputs(context.getIdAllocator(), groupedAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 8da7a2ea2919..47ce4c17308a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -64,6 +64,7 @@ import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.ExpressionUtils.or; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.Patterns.Apply.correlation; import static io.trino.sql.planner.plan.Patterns.applyNode; @@ -216,18 +217,14 @@ private PlanNode buildInPredicateEquivalent( Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT); Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT); - AggregationNode aggregation = new AggregationNode( + AggregationNode aggregation = singleAggregation( idAllocator.getNextId(), preProjection, ImmutableMap.builder() .put(countMatchesSymbol, countWithFilter(session, matchConditionSymbol)) .put(countNullMatchesSymbol, countWithFilter(session, nullMatchConditionSymbol)) .buildOrThrow(), - singleGroupingSet(probeSide.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(probeSide.getOutputSymbols())); // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java index 19e66ec4990f..a2e339723706 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformExistsApplyToCorrelatedJoin.java @@ -22,7 +22,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.PlanNodeDecorrelator; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.ApplyNode; import io.trino.sql.planner.plan.Assignments; @@ -47,6 +46,7 @@ import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.INNER; import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT; import static io.trino.sql.planner.plan.Patterns.applyNode; @@ -174,7 +174,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context contex applyNode.getInput(), new ProjectNode( context.getIdAllocator().getNextId(), - new AggregationNode( + singleAggregation( context.getIdAllocator().getNextId(), applyNode.getSubquery(), ImmutableMap.of(count, new Aggregation( @@ -184,11 +184,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context contex Optional.empty(), Optional.empty(), Optional.empty())), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()), + globalAggregation()), Assignments.of(exists, new ComparisonExpression(GREATER_THAN, count.toSymbolReference(), new Cast(new LongLiteral("0"), toSqlType(BIGINT))))), applyNode.getCorrelation(), INNER, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java index 1e3ea2442777..5cba9ba69b63 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformFilteringSemiJoinToInnerJoin.java @@ -22,7 +22,6 @@ import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.optimizations.PlanNodeSearcher; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; @@ -43,7 +42,7 @@ import static io.trino.sql.ExpressionUtils.and; import static io.trino.sql.ExpressionUtils.extractConjuncts; import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.Patterns.filter; @@ -124,15 +123,11 @@ public Result apply(FilterNode filterNode, Captures captures, Context context) Optional joinFilter = simplifiedPredicate.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(simplifiedPredicate); - PlanNode filteringSourceDistinct = new AggregationNode( + PlanNode filteringSourceDistinct = singleAggregation( context.getIdAllocator().getNextId(), semiJoin.getFilteringSource(), ImmutableMap.of(), - singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol())), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol()))); JoinNode innerJoin = new JoinNode( semiJoin.getId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index 7fac80d83094..522efb8d28fd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -368,15 +368,10 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred preGroupedSymbols = groupingKeys; } - AggregationNode result = new AggregationNode( - node.getId(), - child.getNode(), - node.getAggregations(), - node.getGroupingSets(), - preGroupedSymbols, - node.getStep(), - node.getHashSymbol(), - node.getGroupIdSymbol()); + AggregationNode result = AggregationNode.builderFrom(node) + .setSource(child.getNode()) + .setPreGroupedSymbols(preGroupedSymbols) + .build(); return deriveProperties(result, child.getProperties()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index 24b0b323bf22..c9262e2958db 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -182,15 +182,10 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation Optional hashSymbol = groupByHash.map(child::getRequiredHashSymbol); return new PlanWithProperties( - new AggregationNode( - node.getId(), - child.getNode(), - node.getAggregations(), - node.getGroupingSets(), - node.getPreGroupedSymbols(), - node.getStep(), - hashSymbol, - node.getGroupIdSymbol()), + AggregationNode.builderFrom(node) + .setSource(child.getNode()) + .setHashSymbol(hashSymbol) + .build(), hashSymbol.isPresent() ? ImmutableMap.of(groupByHash.get(), hashSymbol.get()) : ImmutableMap.of()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java index 1638a0436add..b614e92db70c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -64,6 +64,7 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toTypeSignature; import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.optimizations.SymbolMapper.symbolMapper; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.TopNRankingNode.RankingType.ROW_NUMBER; import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; @@ -229,15 +230,11 @@ private Optional rewriteLimitWithRowCountOne(DecorrelationR } // rewrite Limit to aggregation on constant symbols - AggregationNode aggregationNode = new AggregationNode( + AggregationNode aggregationNode = singleAggregation( nodeId, decorrelatedChildNode, ImmutableMap.of(), - singleGroupingSet(decorrelatedChildNode.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(decorrelatedChildNode.getOutputSymbols())); return Optional.of(new DecorrelationResult( aggregationNode, @@ -439,18 +436,14 @@ public Optional visitAggregation(AggregationNode node, Void return Optional.empty(); } - AggregationNode newAggregation = new AggregationNode( - decorrelatedAggregation.getId(), - decorrelatedAggregation.getSource(), - decorrelatedAggregation.getAggregations(), - AggregationNode.singleGroupingSet(ImmutableList.builder() - .addAll(node.getGroupingKeys()) - .addAll(symbolsToAdd) - .build()), - ImmutableList.of(), - decorrelatedAggregation.getStep(), - decorrelatedAggregation.getHashSymbol(), - decorrelatedAggregation.getGroupIdSymbol()); + AggregationNode newAggregation = AggregationNode.builderFrom(decorrelatedAggregation) + .setGroupingSets( + AggregationNode.singleGroupingSet(ImmutableList.builder() + .addAll(node.getGroupingKeys()) + .addAll(symbolsToAdd) + .build())) + .setPreGroupedSymbols(ImmutableList.of()) + .build(); return Optional.of(new DecorrelationResult( newAggregation, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 130577a49d44..4ab16f218e16 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -612,33 +612,33 @@ private DynamicFiltersResult createDynamicFilters( } List clauses = Streams.concat( - equiJoinClauses - .stream() - .map(clause -> new DynamicFilterExpression( - new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()))), - joinFilterClauses.stream() - .flatMap(Rewriter::tryConvertBetweenIntoComparisons) - .filter(clause -> joinDynamicFilteringExpression(clause, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) - .map(expression -> { - if (expression instanceof NotExpression) { - NotExpression notExpression = ((NotExpression) expression); - ComparisonExpression comparison = (ComparisonExpression) notExpression.getValue(); - return new DynamicFilterExpression(new ComparisonExpression(EQUAL, comparison.getLeft(), comparison.getRight()), true); - } - return new DynamicFilterExpression((ComparisonExpression) expression); - }) - .map(expression -> { - ComparisonExpression comparison = expression.getComparison(); - Expression leftExpression = comparison.getLeft(); - Expression rightExpression = comparison.getRight(); - boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression)); - return new DynamicFilterExpression( - new ComparisonExpression( - alignedComparison ? comparison.getOperator() : comparison.getOperator().flip(), - alignedComparison ? leftExpression : rightExpression, - alignedComparison ? rightExpression : leftExpression), - expression.isNullAllowed()); - })) + equiJoinClauses + .stream() + .map(clause -> new DynamicFilterExpression( + new ComparisonExpression(EQUAL, clause.getLeft().toSymbolReference(), clause.getRight().toSymbolReference()))), + joinFilterClauses.stream() + .flatMap(Rewriter::tryConvertBetweenIntoComparisons) + .filter(clause -> joinDynamicFilteringExpression(clause, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) + .map(expression -> { + if (expression instanceof NotExpression) { + NotExpression notExpression = ((NotExpression) expression); + ComparisonExpression comparison = (ComparisonExpression) notExpression.getValue(); + return new DynamicFilterExpression(new ComparisonExpression(EQUAL, comparison.getLeft(), comparison.getRight()), true); + } + return new DynamicFilterExpression((ComparisonExpression) expression); + }) + .map(expression -> { + ComparisonExpression comparison = expression.getComparison(); + Expression leftExpression = comparison.getLeft(); + Expression rightExpression = comparison.getRight(); + boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression)); + return new DynamicFilterExpression( + new ComparisonExpression( + alignedComparison ? comparison.getOperator() : comparison.getOperator().flip(), + alignedComparison ? leftExpression : rightExpression, + alignedComparison ? rightExpression : leftExpression), + expression.isNullAllowed()); + })) .collect(toImmutableList()); // New equiJoinClauses could potentially not contain symbols used in current dynamic filters. @@ -1509,14 +1509,10 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference()); - subqueryPlan = new AggregationNode( + subqueryPlan = singleAggregation( idAllocator.getNextId(), subqueryPlan, ImmutableMap.of( @@ -172,11 +172,7 @@ countNonNullValue, new Aggregation( Optional.empty(), Optional.empty(), Optional.empty())), - globalAggregation(), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + globalAggregation()); PlanNode join = new CorrelatedJoinNode( node.getId(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java index ab98d65001be..a152160e81c7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java @@ -56,6 +56,15 @@ public class AggregationNode private final Optional groupIdSymbol; private final List outputs; + public static AggregationNode singleAggregation( + PlanNodeId id, + PlanNode source, + Map aggregations, + GroupingSetDescriptor groupingSets) + { + return new AggregationNode(id, source, aggregations, groupingSets, ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty()); + } + @JsonCreator public AggregationNode( @JsonProperty("id") PlanNodeId id, @@ -207,7 +216,9 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new AggregationNode(getId(), Iterables.getOnlyElement(newChildren), aggregations, groupingSets, preGroupedSymbols, step, hashSymbol, groupIdSymbol); + return builderFrom(this) + .setSource(Iterables.getOnlyElement(newChildren)) + .build(); } public boolean producesDistinctRows() @@ -479,4 +490,95 @@ private void verifyArguments(Step step) arguments.size()); } } + + public static Builder builderFrom(AggregationNode node) + { + return new Builder(node); + } + + public static class Builder + { + private PlanNodeId id; + private PlanNode source; + private Map aggregations; + private GroupingSetDescriptor groupingSets; + private List preGroupedSymbols; + private Step step; + private Optional hashSymbol; + private Optional groupIdSymbol; + + public Builder(AggregationNode node) + { + requireNonNull(node, "node is null"); + this.id = node.getId(); + this.source = node.getSource(); + this.aggregations = node.getAggregations(); + this.groupingSets = node.getGroupingSets(); + this.preGroupedSymbols = node.getPreGroupedSymbols(); + this.step = node.getStep(); + this.hashSymbol = node.getHashSymbol(); + this.groupIdSymbol = node.getGroupIdSymbol(); + } + + public Builder setId(PlanNodeId id) + { + this.id = requireNonNull(id, "id is null"); + return this; + } + + public Builder setSource(PlanNode source) + { + this.source = requireNonNull(source, "source is null"); + return this; + } + + public Builder setAggregations(Map aggregations) + { + this.aggregations = requireNonNull(aggregations, "aggregations is null"); + return this; + } + + public Builder setGroupingSets(GroupingSetDescriptor groupingSets) + { + this.groupingSets = requireNonNull(groupingSets, "groupingSets is null"); + return this; + } + + public Builder setPreGroupedSymbols(List preGroupedSymbols) + { + this.preGroupedSymbols = requireNonNull(preGroupedSymbols, "preGroupedSymbols is null"); + return this; + } + + public Builder setStep(Step step) + { + this.step = requireNonNull(step, "step is null"); + return this; + } + + public Builder setHashSymbol(Optional hashSymbol) + { + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + return this; + } + + public Builder setGroupIdSymbol(Optional groupIdSymbol) + { + this.groupIdSymbol = requireNonNull(groupIdSymbol, "groupIdSymbol is null"); + return this; + } + + public AggregationNode build() + { + return new AggregationNode( + id, + source, + aggregations, + groupingSets, + preGroupedSymbols, + step, + hashSymbol, + groupIdSymbol); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java index b796d47c1233..bcbaf76c8c75 100644 --- a/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java +++ b/core/trino-main/src/test/java/io/trino/cost/TestCostCalculator.java @@ -71,6 +71,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; @@ -824,15 +825,11 @@ private AggregationNode aggregation(String id, PlanNode source) Optional.empty(), Optional.empty()); - return new AggregationNode( + return singleAggregation( new PlanNodeId(id), source, ImmutableMap.of(new Symbol("count"), aggregation), - singleGroupingSet(source.getOutputSymbols()), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(source.getOutputSymbols())); } /** diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java index ff39b5360640..0ba894957661 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestEffectivePredicateExtractor.java @@ -106,6 +106,7 @@ import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; import static io.trino.sql.planner.plan.AggregationNode.globalAggregation; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL; import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -209,7 +210,7 @@ public void setUp() @Test public void testAggregation() { - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), filter( baseTableScan, @@ -236,11 +237,7 @@ D, new Aggregation( Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(A, B, C)), - ImmutableList.of(), - AggregationNode.Step.SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(A, B, C))); Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java index 0236be628794..52ed4b72c118 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java @@ -24,7 +24,6 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.VarcharType; -import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.Assignments; import io.trino.sql.planner.plan.PlanNode; @@ -57,7 +56,7 @@ import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.sql.planner.TypeAnalyzer.createTestingTypeAnalyzer; -import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -178,7 +177,7 @@ public void testValidAggregation() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -188,11 +187,7 @@ public void testValidAggregation() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertTypesValid(node); } @@ -234,7 +229,7 @@ public void testInvalidAggregationFunctionCall() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -244,11 +239,7 @@ public void testInvalidAggregationFunctionCall() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertThatThrownBy(() -> assertTypesValid(node)) .isInstanceOf(IllegalArgumentException.class) @@ -260,7 +251,7 @@ public void testInvalidAggregationFunctionSignature() { Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", BIGINT); - PlanNode node = new AggregationNode( + PlanNode node = singleAggregation( newId(), baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( @@ -270,11 +261,7 @@ public void testInvalidAggregationFunctionSignature() Optional.empty(), Optional.empty(), Optional.empty())), - singleGroupingSet(ImmutableList.of(columnA, columnB)), - ImmutableList.of(), - SINGLE, - Optional.empty(), - Optional.empty()); + singleGroupingSet(ImmutableList.of(columnA, columnB))); assertThatThrownBy(() -> assertTypesValid(node)) .isInstanceOf(IllegalArgumentException.class)