Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DeleteNode;
import io.trino.sql.planner.plan.ExplainAnalyzeNode;
Expand Down Expand Up @@ -134,6 +133,7 @@
import static io.trino.sql.planner.PlanBuilder.newPlanBuilder;
import static io.trino.sql.planner.QueryPlanner.visibleFields;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.TableWriterNode.CreateReference;
import static io.trino.sql.planner.plan.TableWriterNode.InsertReference;
Expand Down Expand Up @@ -362,15 +362,11 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme

PlanNode planNode = new StatisticsWriterNode(
idAllocator.getNextId(),
new AggregationNode(
singleAggregation(
idAllocator.getNextId(),
TableScanNode.newInstance(idAllocator.getNextId(), targetTable, tableScanOutputs.build(), symbolToColumnHandle.buildOrThrow(), false, Optional.empty()),
statisticAggregations.getAggregations(),
singleGroupingSet(groupingSymbols),
ImmutableList.of(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty()),
singleGroupingSet(groupingSymbols)),
new StatisticsWriterNode.WriteStatisticsReference(targetTable),
symbolAllocator.newSymbol("rows", BIGINT),
tableStatisticsMetadata.getTableStatistics().contains(ROW_COUNT),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ protected PlanNode visitPlan(PlanNode node, RewriteContext<Void> context)
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
return new AggregationNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAggregations(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol());
return AggregationNode.builderFrom(node)
.setId(idAllocator.getNextId())
.setSource(context.rewrite(node.getSource()))
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import static io.trino.sql.planner.PlanBuilder.newPlanBuilder;
import static io.trino.sql.planner.ScopeAware.scopeAwareKey;
import static io.trino.sql.planner.plan.AggregationNode.groupingSets;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
Expand Down Expand Up @@ -333,15 +334,11 @@ public RelationPlan planExpand(Query query)
PlanNode result = new UnionNode(idAllocator.getNextId(), nodesToUnion, unionSymbolMapping.build(), unionOutputSymbols);

if (union.isDistinct()) {
result = new AggregationNode(
result = singleAggregation(
idAllocator.getNextId(),
result,
ImmutableMap.of(),
singleGroupingSet(result.getOutputSymbols()),
ImmutableList.of(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty());
singleGroupingSet(result.getOutputSymbols()));
}

return new RelationPlan(result, anchorPlan.getScope(), unionOutputSymbols, outerContext);
Expand Down Expand Up @@ -1654,15 +1651,11 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node, List<
.collect(Collectors.toList());

return subPlan.withNewRoot(
new AggregationNode(
singleAggregation(
idAllocator.getNextId(),
subPlan.getRoot(),
ImmutableMap.of(),
singleGroupingSet(symbols),
ImmutableList.of(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty()));
singleGroupingSet(symbols)));
}

return subPlan;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.trino.sql.analyzer.Field;
import io.trino.sql.analyzer.RelationType;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.ExceptNode;
Expand Down Expand Up @@ -122,6 +121,7 @@
import static io.trino.sql.planner.QueryPlanner.extractPatternRecognitionExpressions;
import static io.trino.sql.planner.QueryPlanner.planWindowSpecification;
import static io.trino.sql.planner.QueryPlanner.pruneInvisibleFields;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static io.trino.sql.tree.Join.Type.CROSS;
Expand Down Expand Up @@ -1160,14 +1160,10 @@ private SetOperationPlan process(SetOperation node)

private PlanNode distinct(PlanNode node)
{
return new AggregationNode(idAllocator.getNextId(),
return singleAggregation(idAllocator.getNextId(),
node,
ImmutableMap.of(),
singleGroupingSet(node.getOutputSymbols()),
ImmutableList.of(),
AggregationNode.Step.SINGLE,
Optional.empty(),
Optional.empty());
singleGroupingSet(node.getOutputSymbols()));
}

private static final class SetOperationPlan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,12 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI
{
verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
return new AggregationNode(
idAllocator.getNextId(),
gatheringExchange,
outputsAsInputs(aggregation.getAggregations()),
aggregation.getGroupingSets(),
aggregation.getPreGroupedSymbols(),
AggregationNode.Step.INTERMEDIATE,
aggregation.getHashSymbol(),
aggregation.getGroupIdSymbol());
return AggregationNode.builderFrom(aggregation)
.setId(idAllocator.getNextId())
.setSource(gatheringExchange)
.setAggregations(outputsAsInputs(aggregation.getAggregations()))
.setStep(AggregationNode.Step.INTERMEDIATE)
.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AggregationNode.Aggregation;
import io.trino.sql.planner.plan.PlanNode;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;

class AggregationDecorrelation
{
private AggregationDecorrelation() {}
Expand Down Expand Up @@ -51,4 +55,26 @@ public static Map<Symbol, Aggregation> rewriteWithMasks(Map<Symbol, Aggregation>

return rewritten.buildOrThrow();
}

/**
* Creates distinct aggregation node based on existing distinct aggregation node.
*
* @see #isDistinctOperator(PlanNode)
*/
public static AggregationNode restoreDistinctAggregation(
Comment thread
lukasz-stec marked this conversation as resolved.
Outdated
AggregationNode distinct,
PlanNode source,
List<Symbol> groupingKeys)
{
checkArgument(isDistinctOperator(distinct));
return new AggregationNode(
Comment thread
lukasz-stec marked this conversation as resolved.
Outdated
distinct.getId(),
source,
ImmutableMap.of(),
AggregationNode.singleGroupingSet(groupingKeys),
ImmutableList.of(),
distinct.getStep(),
Optional.empty(),
Optional.empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
Expand Down Expand Up @@ -337,15 +338,11 @@ private static AggregationNode withGroupingAndMask(AggregationNode aggregationNo
.build());
}

return new AggregationNode(
return singleAggregation(
aggregationNode.getId(),
source,
rewriteWithMasks(aggregationNode.getAggregations(), masks.buildOrThrow()),
singleGroupingSet(groupingSymbols),
ImmutableList.of(),
SINGLE,
Optional.empty(),
Optional.empty());
singleGroupingSet(groupingSymbols));
}

private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> groupingSymbols, PlanNode source)
Expand All @@ -354,14 +351,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis
.distinct()
.collect(toImmutableList()));

return new AggregationNode(
return singleAggregation(
aggregationNode.getId(),
source,
aggregationNode.getAggregations(),
groupingSet,
ImmutableList.of(),
SINGLE,
Optional.empty(),
Optional.empty());
groupingSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
Expand Down Expand Up @@ -42,6 +41,7 @@
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.JoinNode.Type.LEFT;
import static io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation;
Expand Down Expand Up @@ -264,14 +264,10 @@ private static AggregationNode withGrouping(AggregationNode aggregationNode, Lis
.distinct()
.collect(toImmutableList()));

return new AggregationNode(
return singleAggregation(
aggregationNode.getId(),
source,
aggregationNode.getAggregations(),
groupingSet,
ImmutableList.of(),
SINGLE,
Optional.empty(),
Optional.empty());
groupingSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,9 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
}
}
if (anyRewritten) {
return Result.ofPlanNode(new AggregationNode(
aggregationNode.getId(),
aggregationNode.getSource(),
aggregations.buildOrThrow(),
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
return Result.ofPlanNode(AggregationNode.builderFrom(aggregationNode)
.setAggregations(aggregations.buildOrThrow())
.build());
}
return Result.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,17 @@ else if (mask.isPresent()) {
newAssignments.putIdentities(aggregationNode.getSource().getOutputSymbols());

return Result.ofPlanNode(
new AggregationNode(
context.getIdAllocator().getNextId(),
new FilterNode(
AggregationNode.builderFrom(aggregationNode)
.setId(context.getIdAllocator().getNextId())
.setSource(new FilterNode(
context.getIdAllocator().getNextId(),
new ProjectNode(
context.getIdAllocator().getNextId(),
aggregationNode.getSource(),
newAssignments.build()),
predicate),
aggregations.buildOrThrow(),
aggregationNode.getGroupingSets(),
ImmutableList.of(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
predicate))
.setAggregations(aggregations.buildOrThrow())
.setPreGroupedSymbols(ImmutableList.of())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,10 @@ public Result apply(AggregationNode parent, Captures captures, Context context)
}

return Result.ofPlanNode(
new AggregationNode(
parent.getId(),
subPlan,
newAggregations,
parent.getGroupingSets(),
ImmutableList.of(),
parent.getStep(),
parent.getHashSymbol(),
parent.getGroupIdSymbol()));
AggregationNode.builderFrom(parent)
.setSource(subPlan)
.setAggregations(newAggregations)
.setPreGroupedSymbols(ImmutableList.of())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,8 @@ protected Optional<PlanNode> pushDownProjectOff(

// PruneAggregationSourceColumns will subsequently project off any newly unused inputs.
return Optional.of(
new AggregationNode(
aggregationNode.getId(),
aggregationNode.getSource(),
prunedAggregations,
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
AggregationNode.builderFrom(aggregationNode)
.setAggregations(prunedAggregations)
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,10 @@ public PlanNode visitAggregation(AggregationNode node, Boolean context)
return rewrittenNode;
}

return new AggregationNode(
node.getId(),
rewrittenNode,
node.getAggregations(),
node.getGroupingSets(),
ImmutableList.of(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol());
return AggregationNode.builderFrom(node)
.setSource(rewrittenNode)
.setPreGroupedSymbols(ImmutableList.of())
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,8 @@ else if (metadata.getAggregationFunctionMetadata(context.getSession(), aggregati
if (!anyRewritten) {
return Result.empty();
}
return Result.ofPlanNode(new AggregationNode(
node.getId(),
node.getSource(),
aggregations.buildOrThrow(),
node.getGroupingSets(),
node.getPreGroupedSymbols(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol()));
return Result.ofPlanNode(AggregationNode.builderFrom(node)
.setAggregations(aggregations.buildOrThrow())
.build());
}
}
Loading