diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPlanOptimizer.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPlanOptimizer.java index b7446e5c36af..6558100a9f67 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPlanOptimizer.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPlanOptimizer.java @@ -173,6 +173,7 @@ private AggregationNode simpleAggregationSum(PlanBuilder pb, PlanNode source, Va ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java index e2d52484e231..9c96e9fe9c0c 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveHistoryBasedStatsTracking.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.SqlQueryManager; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.sql.planner.Plan; @@ -32,9 +33,11 @@ import org.testng.annotations.Test; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY; import static com.facebook.presto.SystemSessionProperties.RESTRICT_HISTORY_BASED_OPTIMIZATION_TO_COMPLEX_QUERY; import static com.facebook.presto.SystemSessionProperties.TRACK_HISTORY_BASED_PLAN_STATISTICS; import static com.facebook.presto.SystemSessionProperties.USE_HISTORY_BASED_PLAN_STATISTICS; +import static com.facebook.presto.SystemSessionProperties.USE_PARTIAL_AGGREGATION_HISTORY; 
import static com.facebook.presto.hive.HiveQueryRunner.HIVE_CATALOG; import static com.facebook.presto.hive.HiveSessionProperties.PUSHDOWN_FILTER_ENABLED; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; @@ -80,7 +83,7 @@ public void testHistoryBasedStatsCalculator() anyTree(node(ProjectNode.class, any())).withOutputRowCount(229.5)); // HBO Statistics - executeAndTrackHistory("SELECT *, 1 FROM test_orders where ds = '2020-09-01' and substr(orderpriority, 1, 1) = '1'"); + executeAndTrackHistory("SELECT *, 1 FROM test_orders where ds = '2020-09-01' and substr(orderpriority, 1, 1) = '1'", defaultSession()); assertPlan( "SELECT *, 2 FROM test_orders where ds = '2020-09-02' and substr(orderpriority, 1, 1) = '1'", anyTree(node(ProjectNode.class, any()).withOutputRowCount(48))); @@ -96,7 +99,7 @@ public void testInsertTable() try { getQueryRunner().execute("CREATE TABLE test_orders (orderkey integer, ds varchar) WITH (partitioned_by = ARRAY['ds'])"); - Plan plan = plan("insert into test_orders (values (1, '2023-09-20'), (2, '2023-09-21'))", createSession()); + Plan plan = plan("insert into test_orders (values (1, '2023-09-20'), (2, '2023-09-21'))", defaultSession()); assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) .where(node -> node instanceof TableWriterMergeNode && !node.getStatsEquivalentPlanNode().isPresent()) @@ -125,7 +128,7 @@ public void testBroadcastJoin() // CBO Statistics Plan plan = plan("SELECT * FROM " + "(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " + - "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession()); + "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", defaultSession()); assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) .where(node -> node instanceof 
JoinNode && ((JoinNode) node).getDistributionType().get().equals(JoinNode.DistributionType.PARTITIONED)) @@ -135,11 +138,12 @@ public void testBroadcastJoin() // HBO Statistics executeAndTrackHistory("SELECT * FROM " + "(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " + - "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey"); + "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", + defaultSession()); plan = plan("SELECT * FROM " + "(SELECT * FROM test_orders where ds = '2020-09-01' and substr(CAST(custkey AS VARCHAR), 1, 3) <> '370') t1 JOIN " + - "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", createSession()); + "(SELECT * FROM test_orders where ds = '2020-09-02' and substr(CAST(custkey AS VARCHAR), 1, 3) = '370') t2 ON t1.orderkey = t2.orderkey", defaultSession()); assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) .where(node -> node instanceof JoinNode && ((JoinNode) node).getDistributionType().get().equals(JoinNode.DistributionType.REPLICATED)) @@ -151,28 +155,65 @@ public void testBroadcastJoin() } } + @Test + public void testPartialAggStatistics() + { + try { + // CBO Statistics + getQueryRunner().execute("CREATE TABLE test_orders WITH (partitioned_by = ARRAY['ds', 'ts']) AS " + + "SELECT orderkey, orderpriority, comment, custkey, '2020-09-01' as ds, '00:01' as ts FROM orders where orderkey < 2000 "); + + String query = "SELECT count(*) FROM test_orders group by custkey"; + Session session = createSession("always"); + Plan plan = plan(query, session); + + assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) + .where(node -> node instanceof AggregationNode && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL) + 
.findFirst() + .isPresent()); + + // collect HBO Statistics + executeAndTrackHistory(query, createSession("always")); + + plan = plan(query, createSession("automatic")); + + assertTrue(PlanNodeSearcher.searchFrom(plan.getRoot()) + .where(node -> node instanceof AggregationNode && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL).findAll().isEmpty()); + } + finally { + getQueryRunner().execute("DROP TABLE IF EXISTS test_orders"); + } + } + @Override protected void assertPlan(@Language("SQL") String query, PlanMatchPattern pattern) { - assertPlan(createSession(), query, pattern); + assertPlan(defaultSession(), query, pattern); } - private void executeAndTrackHistory(String sql) + private void executeAndTrackHistory(String sql, Session session) { DistributedQueryRunner queryRunner = (DistributedQueryRunner) getQueryRunner(); SqlQueryManager sqlQueryManager = (SqlQueryManager) queryRunner.getCoordinator().getQueryManager(); InMemoryHistoryBasedPlanStatisticsProvider provider = (InMemoryHistoryBasedPlanStatisticsProvider) sqlQueryManager.getHistoryBasedPlanStatisticsTracker().getHistoryBasedPlanStatisticsProvider(); - queryRunner.execute(createSession(), sql); + queryRunner.execute(session, sql); provider.waitProcessQueryEvents(); } - private Session createSession() + private Session defaultSession() + { + return createSession("automatic"); + } + + private Session createSession(String partialAggregationStrategy) { return Session.builder(getQueryRunner().getDefaultSession()) .setSystemProperty(USE_HISTORY_BASED_PLAN_STATISTICS, "true") .setSystemProperty(TRACK_HISTORY_BASED_PLAN_STATISTICS, "true") .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "automatic") + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, partialAggregationStrategy) + .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true") .setCatalogSessionProperty(HIVE_CATALOG, PUSHDOWN_FILTER_ENABLED, "true") .setSystemProperty(RESTRICT_HISTORY_BASED_OPTIMIZATION_TO_COMPLEX_QUERY, "false") 
.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index fb6dea67fc18..6d1511df299c 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -294,6 +294,7 @@ public final class SystemSessionProperties public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression"; public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates"; public static final String ENABLE_HISTORY_BASED_SCALED_WRITER = "enable_history_based_scaled_writer"; + public static final String USE_PARTIAL_AGGREGATION_HISTORY = "use_partial_aggregation_history"; public static final String REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN = "remove_redundant_cast_to_varchar_in_join"; public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins"; @@ -1772,6 +1773,11 @@ public SystemSessionProperties( "Enable setting the initial number of tasks for scaled writers with HBO", featuresConfig.isUseHBOForScaledWriters(), false), + booleanProperty( + USE_PARTIAL_AGGREGATION_HISTORY, + "Use collected partial aggregation statistics from HBO", + featuresConfig.isUsePartialAggregationHistory(), + false), booleanProperty( REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "If both left and right side of join clause are varchar cast from int/bigint, remove the cast here", @@ -2961,6 +2967,11 @@ public static boolean useHistoryBasedScaledWriters(Session session) return session.getSystemProperty(ENABLE_HISTORY_BASED_SCALED_WRITER, Boolean.class); } + public static boolean usePartialAggregationHistory(Session session) + { + return session.getSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, Boolean.class); + } + public static boolean isRemoveRedundantCastToVarcharInJoinEnabled(Session session) { 
return session.getSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, Boolean.class); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java index 56b7226b4377..716ffea7f49f 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java @@ -22,14 +22,17 @@ import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.StageInfo; import com.facebook.presto.metadata.SessionPropertyManager; +import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeWithHash; +import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.statistics.HistoryBasedSourceInfo; import com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.PlanStatisticsWithSourceInfo; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; @@ -123,8 +126,10 @@ public Map getQueryStats(QueryIn List allStages = outputStage.getAllStages(); Map planNodeStatsMap = aggregateStageStats(allStages); - Map planStatistics = new HashMap<>(); + Map planStatisticsMap = new HashMap<>(); Map canonicalInfoMap = new HashMap<>(); + Map aggregationNodeMap = new HashMap<>(); + queryInfo.getPlanCanonicalInfo().forEach(canonicalPlanWithInfo -> { // We can have duplicate stats equivalent 
plan nodes. It's ok to use any stats in this case canonicalInfoMap.putIfAbsent(canonicalPlanWithInfo.getCanonicalPlan(), canonicalPlanWithInfo.getInfo()); @@ -137,19 +142,32 @@ public Map getQueryStats(QueryIn boolean isScaledWriterStage = stageInfo.getPlan().isPresent() && stageInfo.getPlan().get().getPartitioning().equals(SCALED_WRITER_DISTRIBUTION); PlanNode root = stageInfo.getPlan().get().getRoot(); for (PlanNode planNode : forTree(PlanNode::getSources).depthFirstPreOrder(root)) { - if (!planNode.getStatsEquivalentPlanNode().isPresent()) { + if (!planNode.getStatsEquivalentPlanNode().isPresent() && !isAggregation(planNode, AggregationNode.Step.PARTIAL)) { continue; } PlanNodeStats planNodeStats = planNodeStatsMap.get(planNode.getId()); if (planNodeStats == null) { continue; } + double outputPositions = planNodeStats.getPlanNodeOutputPositions(); double outputBytes = adjustedOutputBytes(planNode, planNodeStats); double nullJoinBuildKeyCount = planNodeStats.getPlanNodeNullJoinBuildKeyCount(); double joinBuildKeyCount = planNodeStats.getPlanNodeJoinBuildKeyCount(); double nullJoinProbeKeyCount = planNodeStats.getPlanNodeNullJoinProbeKeyCount(); double joinProbeKeyCount = planNodeStats.getPlanNodeJoinProbeKeyCount(); + PartialAggregationStatistics partialAggregationStatistics = PartialAggregationStatistics.empty(); + + if (isAggregation(planNode, AggregationNode.Step.PARTIAL)) { + // we're doing a depth-first traversal of the plan tree so we must have seen the corresponding final agg already: + // find it and update its partial agg stats + partialAggregationStatistics = constructAggregationNodeStatistics(planNode, planNodeStatsMap, outputBytes, outputPositions); + updatePartialAggregationStatistics((AggregationNode) planNode, aggregationNodeMap, partialAggregationStatistics, planStatisticsMap); + } + + if (!planNode.getStatsEquivalentPlanNode().isPresent()) { + continue; + } JoinNodeStatistics joinNodeStatistics = JoinNodeStatistics.empty(); if (planNode 
instanceof JoinNode) { @@ -176,21 +194,70 @@ public Map getQueryStats(QueryIn Double.isNaN(outputBytes) ? Estimate.unknown() : Estimate.of(outputBytes), 1.0, joinNodeStatistics, - tableWriterNodeStatistics); - if (planStatistics.containsKey(planNodeWithHash)) { - newPlanNodeStats = planStatistics.get(planNodeWithHash).getPlanStatistics().update(newPlanNodeStats); + tableWriterNodeStatistics, + partialAggregationStatistics); + if (planStatisticsMap.containsKey(planNodeWithHash)) { + newPlanNodeStats = planStatisticsMap.get(planNodeWithHash).getPlanStatistics().update(newPlanNodeStats); + } + PlanStatisticsWithSourceInfo planStatsWithSourceInfo = new PlanStatisticsWithSourceInfo( + planNode.getId(), + newPlanNodeStats, + new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics))); + planStatisticsMap.put(planNodeWithHash, planStatsWithSourceInfo); + + if (isAggregation(planNode, AggregationNode.Step.FINAL) && ((AggregationNode) planNode).getAggregationId().isPresent()) { + // we're doing a depth-first traversal of the plan tree: cache the final agg so that when we encounter the partial agg we can come back + // and update the partial agg statistics + aggregationNodeMap.put(((AggregationNode) planNode).getAggregationId().get(), new FinalAggregationStatsInfo(planNodeWithHash, planStatsWithSourceInfo)); } - planStatistics.put( - planNodeWithHash, - new PlanStatisticsWithSourceInfo( - planNode.getId(), - newPlanNodeStats, - new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics)))); } } } } - return ImmutableMap.copyOf(planStatistics); + return ImmutableMap.copyOf(planStatisticsMap); + } + + private static void updatePartialAggregationStatistics( + AggregationNode partialAggregationNode, + Map aggregationNodeStats, + PartialAggregationStatistics partialAggregationStatistics, + Map planStatisticsMap) + { + if (!partialAggregationNode.getAggregationId().isPresent() || 
!aggregationNodeStats.containsKey(partialAggregationNode.getAggregationId().get())) { + return; + } + + // find the stats for the matching final aggregation node (the partial and the final node share the same aggregationId) + FinalAggregationStatsInfo finalAggregationStatsInfo = aggregationNodeStats.get(partialAggregationNode.getAggregationId().get()); + PlanStatisticsWithSourceInfo planStatisticsWithSourceInfo = finalAggregationStatsInfo.getPlanStatisticsWithSourceInfo(); + PlanStatistics planStatisticsFinalAgg = planStatisticsWithSourceInfo.getPlanStatistics(); + planStatisticsFinalAgg = planStatisticsFinalAgg.updateAggregationStatistics(partialAggregationStatistics); + + planStatisticsMap.put( + finalAggregationStatsInfo.getPlanNodeWithHash(), + new PlanStatisticsWithSourceInfo( + planStatisticsWithSourceInfo.getId(), + planStatisticsFinalAgg, + planStatisticsWithSourceInfo.getSourceInfo())); + } + + private PartialAggregationStatistics constructAggregationNodeStatistics(PlanNode planNode, Map planNodeStatsMap, double outputBytes, double outputPositions) + { + PlanNode childNode = planNode.getSources().get(0); + PlanNodeStats childNodeStats = planNodeStatsMap.get(childNode.getId()); + if (childNodeStats != null) { + double partialAggregationInputBytes = adjustedOutputBytes(childNode, childNodeStats); + return new PartialAggregationStatistics(Estimate.of(partialAggregationInputBytes), + Estimate.of(outputBytes), + Estimate.of(childNodeStats.getPlanNodeOutputPositions()), + Estimate.of(outputPositions)); + } + return PartialAggregationStatistics.empty(); + } + + private boolean isAggregation(PlanNode planNode, AggregationNode.Step step) + { + return planNode instanceof AggregationNode && ((AggregationNode) planNode).getStep() == step; } // After we assign stats equivalent plan node, additional variables may be introduced by optimizer, for example @@ -205,10 +272,18 @@ private double adjustedOutputBytes(PlanNode planNode, PlanNodeStats planNodeStat outputBytes -= 
planNode.getOutputVariables().stream() .mapToDouble(variable -> variable.getType() instanceof FixedWidthType ? outputPositions * ((FixedWidthType) variable.getType()).getFixedSize() : 0) .sum(); - outputBytes += planNode.getStatsEquivalentPlanNode().get().getOutputVariables().stream() + // partial aggregation nodes have no stats equivalent plan node: use original output variables + List outputVariables = planNode.getOutputVariables(); + if (planNode.getStatsEquivalentPlanNode().isPresent()) { + outputVariables = planNode.getStatsEquivalentPlanNode().get().getOutputVariables(); + } + outputBytes += outputVariables.stream() .mapToDouble(variable -> variable.getType() instanceof FixedWidthType ? outputPositions * ((FixedWidthType) variable.getType()).getFixedSize() : 0) .sum(); - if (outputBytes < 0 || (outputPositions > 0 && outputBytes < 1)) { + // annotate illegal cases with NaN: if outputBytes is less than 0, or if there is at least 1 output row but less than 1 output byte + // Note that this function may be called for partial aggs that produce no output columns (e.g. 
"select count(*)"), where 0 is a valid number of output bytes, in this case + // the list of output variables is empty + if (outputBytes < 0 || (outputPositions > 0 && outputBytes < 1 && !outputVariables.isEmpty())) { outputBytes = Double.NaN; } return outputBytes; @@ -246,4 +321,26 @@ public void updateStatistics(QueryInfo queryInfo) } historyBasedStatisticsCacheManager.invalidate(queryInfo.getQueryId()); } + + private class FinalAggregationStatsInfo + { + private final PlanNodeWithHash planNodeWithHash; + private final PlanStatisticsWithSourceInfo planStatisticsWithSourceInfo; + + FinalAggregationStatsInfo(PlanNodeWithHash planNodeWithHash, PlanStatisticsWithSourceInfo planStatisticsWithSourceInfo) + { + this.planNodeWithHash = requireNonNull(planNodeWithHash); + this.planStatisticsWithSourceInfo = requireNonNull(planStatisticsWithSourceInfo); + } + + public PlanNodeWithHash getPlanNodeWithHash() + { + return planNodeWithHash; + } + + public PlanStatisticsWithSourceInfo getPlanStatisticsWithSourceInfo() + { + return planStatisticsWithSourceInfo; + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PartialAggregationStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PartialAggregationStatsEstimate.java new file mode 100644 index 000000000000..7739bc47becc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/PartialAggregationStatsEstimate.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.lang.Double.NaN; + +public class PartialAggregationStatsEstimate +{ + private static final PartialAggregationStatsEstimate UNKNOWN = new PartialAggregationStatsEstimate(NaN, NaN, NaN, NaN); + + private final double inputBytes; + private final double outputBytes; + + private final double inputRowCount; + private final double outputRowCount; + + @JsonCreator + public PartialAggregationStatsEstimate(@JsonProperty("inputBytes") double inputBytes, @JsonProperty("outputBytes") double outputBytes, + @JsonProperty("inputRowCount") double inputRowCount, @JsonProperty("outputRowCount") double outputRowCount) + { + this.inputBytes = inputBytes; + this.outputBytes = outputBytes; + this.inputRowCount = inputRowCount; + this.outputRowCount = outputRowCount; + } + + public static PartialAggregationStatsEstimate unknown() + { + return UNKNOWN; + } + + @JsonProperty + public double getInputBytes() + { + return inputBytes; + } + + @JsonProperty + public double getOutputBytes() + { + return outputBytes; + } + + @JsonProperty + public double getInputRowCount() + { + return inputRowCount; + } + + @JsonProperty + public double getOutputRowCount() + { + return outputRowCount; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("inputBytes", inputBytes) + .add("outputBytes", outputBytes) + .add("inputRowCount", inputRowCount) + .add("outputRowCount", outputRowCount) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialAggregationStatsEstimate that = 
(PartialAggregationStatsEstimate) o; + return Double.compare(inputBytes, that.inputBytes) == 0 && + Double.compare(outputBytes, that.outputBytes) == 0 && + Double.compare(inputRowCount, that.inputRowCount) == 0 && + Double.compare(outputRowCount, that.outputRowCount) == 0; + } + + @Override + public int hashCode() + { + return Objects.hash(inputBytes, outputBytes, inputRowCount, outputRowCount); + } + + public static boolean isUnknown(PartialAggregationStatsEstimate partialAggregationStatsEstimate) + { + return partialAggregationStatsEstimate == PartialAggregationStatsEstimate.unknown(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index 8293de807d58..e825c7565260 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.statistics.CostBasedSourceInfo; import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.PlanStatisticsWithSourceInfo; import com.facebook.presto.spi.statistics.SourceInfo; @@ -50,7 +51,7 @@ public class PlanNodeStatsEstimate { private static final double DEFAULT_DATA_SIZE_PER_COLUMN = 50; - private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); + private static final PlanNodeStatsEstimate UNKNOWN = new PlanNodeStatsEstimate(NaN, NaN, false, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown()); private final double 
outputRowCount; private final double totalSize; @@ -62,6 +63,8 @@ public class PlanNodeStatsEstimate private final TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate; + private final PartialAggregationStatsEstimate partialAggregationStatsEstimate; + public static PlanNodeStatsEstimate unknown() { return UNKNOWN; @@ -74,9 +77,13 @@ public PlanNodeStatsEstimate( @JsonProperty("confident") boolean confident, @JsonProperty("variableStatistics") Map variableStatistics, @JsonProperty("joinNodeStatsEstimate") JoinNodeStatsEstimate joinNodeStatsEstimate, - @JsonProperty("tableWriterNodeStatsEstimate") TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate) + @JsonProperty("tableWriterNodeStatsEstimate") TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate, + @JsonProperty("partialAggregationStatsEstimate") PartialAggregationStatsEstimate partialAggregationStatsEstimate) { - this(outputRowCount, totalSize, HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")), new CostBasedSourceInfo(confident), joinNodeStatsEstimate, tableWriterNodeStatsEstimate); + this(outputRowCount, + totalSize, + HashTreePMap.from(requireNonNull(variableStatistics, "variableStatistics is null")), + new CostBasedSourceInfo(confident), joinNodeStatsEstimate, tableWriterNodeStatsEstimate, partialAggregationStatsEstimate); } private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean confident, PMap variableStatistics) @@ -86,11 +93,11 @@ private PlanNodeStatsEstimate(double outputRowCount, double totalSize, boolean c public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics, SourceInfo sourceInfo) { - this(outputRowCount, totalSize, variableStatistics, sourceInfo, JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); + this(outputRowCount, totalSize, variableStatistics, sourceInfo, JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), 
PartialAggregationStatsEstimate.unknown()); } public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics, SourceInfo sourceInfo, - JoinNodeStatsEstimate joinNodeStatsEstimate, TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate) + JoinNodeStatsEstimate joinNodeStatsEstimate, TableWriterNodeStatsEstimate tableWriterNodeStatsEstimate, PartialAggregationStatsEstimate partialAggregationStatsEstimate) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; @@ -99,6 +106,7 @@ public PlanNodeStatsEstimate(double outputRowCount, double totalSize, PMap variableStatistics; + private PartialAggregationStatsEstimate partialAggregationStatsEstimate; public Builder() { @@ -350,6 +371,7 @@ private Builder(double outputRowCount, double totalSize, boolean confident, PMap this.totalSize = totalSize; this.confident = confident; this.variableStatistics = variableStatistics; + this.partialAggregationStatsEstimate = PartialAggregationStatsEstimate.unknown(); } public Builder setOutputRowCount(double outputRowCount) @@ -370,6 +392,12 @@ public Builder setConfident(boolean confident) return this; } + public Builder setPartialAggregationStatsEstimate(PartialAggregationStatsEstimate partialAggregationStatsEstimate) + { + this.partialAggregationStatsEstimate = partialAggregationStatsEstimate; + return this; + } + public Builder addVariableStatistics(VariableReferenceExpression variable, VariableStatsEstimate statistics) { variableStatistics = variableStatistics.plus(variable, statistics); @@ -390,7 +418,13 @@ public Builder removeVariableStatistics(VariableReferenceExpression variable) public PlanNodeStatsEstimate build() { - return new PlanNodeStatsEstimate(outputRowCount, totalSize, confident, variableStatistics); + return new PlanNodeStatsEstimate(outputRowCount, + totalSize, + confident, + variableStatistics, + JoinNodeStatsEstimate.unknown(), + 
TableWriterNodeStatsEstimate.unknown(), + partialAggregationStatsEstimate); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 142486b2d73a..c9fd4a8927e7 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -284,6 +284,8 @@ public class FeaturesConfig private boolean handleComplexEquiJoins; private boolean useHBOForScaledWriters; + private boolean usePartialAggregationHistory; + private boolean removeRedundantCastToVarcharInJoin = true; public enum PartitioningPrecisionStrategy @@ -2820,6 +2822,19 @@ public FeaturesConfig setUseHBOForScaledWriters(boolean useHBOForScaledWriters) return this; } + public boolean isUsePartialAggregationHistory() + { + return this.usePartialAggregationHistory; + } + + @Config("optimizer.use-partial-aggregation-history") + @ConfigDescription("Use partial aggregation histories for splitting aggregations") + public FeaturesConfig setUsePartialAggregationHistory(boolean usePartialAggregationHistory) + { + this.usePartialAggregationHistory = usePartialAggregationHistory; + return this; + } + public boolean isRemoveRedundantCastToVarcharInJoin() { return removeRedundantCastToVarcharInJoin; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java index 1e724ba2fc59..2e47ba101d84 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CachingPlanCanonicalInfoProvider.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableScanNode; import 
com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; @@ -160,7 +161,7 @@ private PlanStatistics getPlanStatisticsForTable(Session session, TableScanNode if (profileRuntime) { profileTime("ReadFromMetaData", startProfileTime, session); } - planStatistics = new PlanStatistics(tableStatistics.getRowCount(), tableStatistics.getTotalSize(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); + planStatistics = new PlanStatistics(tableStatistics.getRowCount(), tableStatistics.getTotalSize(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()); cache.put(key, planStatistics); return planStatistics; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 746af7070552..84578460246b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -794,7 +794,9 @@ public Optional visitAggregation(AggregationNode node, Context context .collect(toImmutableList()), node.getStep(), node.getHashVariable().map(ignored -> variableAllocator.newHashVariable()), - node.getGroupIdVariable().map(variable -> context.getExpressions().get(variable))); + node.getGroupIdVariable().map(variable -> context.getExpressions().get(variable)), + // ignore aggregationId when creating the canonical plan + Optional.empty()); context.addPlan(node, new CanonicalPlan(canonicalPlan, strategy)); return Optional.of(canonicalPlan); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 3b6bc58bf6a4..642ba7f1165e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -232,6 +232,7 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()), targetTable, variableAllocator.newVariable(getSourceLocation(analyzeStatement), "rows", BIGINT), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 8f6e1fc8e8b7..ea4c85daf428 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -264,6 +264,7 @@ public static PlanNode addAggregation(PlanNode planNode, FunctionAndTypeManager ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()), planNodeIdAllocator, variableAllocator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 55c5e0717ee3..68b7840da5f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -686,7 +686,8 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), - groupIdVariable); + groupIdVariable, + Optional.empty()); subPlan = new PlanBuilder(aggregationTranslations, aggregationNode); @@ -1126,6 +1127,7 @@ private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node) ImmutableList.of(), 
AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index ba32474b45d9..dbf53c24f900 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -1027,6 +1027,7 @@ private PlanNode distinct(PlanNode node) ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index 7e1d782d43de..0cf4fb820cd4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -130,7 +130,8 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont aggregation.getPreGroupedVariables(), INTERMEDIATE, aggregation.getHashVariable(), - aggregation.getGroupIdVariable()); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId()); source = gatheringExchange(idAllocator.getNextId(), LOCAL, source); } @@ -174,7 +175,8 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI aggregation.getPreGroupedVariables(), INTERMEDIATE, aggregation.getHashVariable(), - aggregation.getGroupIdVariable()); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId()); } /** diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxPercentileFunctions.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxPercentileFunctions.java index 0e802caa6380..c60328c00479 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxPercentileFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxPercentileFunctions.java @@ -318,7 +318,8 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context ImmutableList.of(), aggregationNode.getStep(), aggregationNode.getHashVariable(), - aggregationNode.getGroupIdVariable()), + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId()), outputProjectAssignments.build())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 4f58c0b0fa62..a336371bf9dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -151,6 +151,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont ImmutableList.of(), aggregation.getStep(), aggregation.getHashVariable(), - aggregation.getGroupIdVariable())); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeDuplicateAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeDuplicateAggregation.java index 171aa3ffdfee..6fa29126a3d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeDuplicateAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MergeDuplicateAggregation.java 
@@ -92,7 +92,8 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable()), + node.getGroupIdVariable(), + node.getAggregationId()), assignments.build(), LOCAL)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java index dae05afc1a9a..640ed21059da 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java @@ -168,6 +168,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) ImmutableList.of(), parent.getStep(), parent.getHashVariable(), - parent.getGroupIdVariable())); + parent.getGroupIdVariable(), + parent.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java index 41394f816020..9688d0d72d06 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneAggregationColumns.java @@ -58,6 +58,7 @@ protected Optional pushDownProjectOff( aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), - aggregationNode.getGroupIdVariable())); + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java index cd51a26093d0..9bde380771e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java @@ -87,6 +87,7 @@ else if (functionAndTypeManager.getAggregateFunctionImplementation(aggregation.g node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable())); + node.getGroupIdVariable(), + node.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java index c77894a0a99d..9cc02e3a1ab5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PullConstantsAboveGroupBy.java @@ -112,7 +112,8 @@ public Result apply(AggregationNode parent, Captures captures, Context context) ImmutableList.of(), parent.getStep(), parent.getHashVariable(), - parent.getGroupIdVariable()); + parent.getGroupIdVariable(), + parent.getAggregationId()); Map remainingVars = outputVariables.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index 46ad554745ea..e567f1900669 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -152,7 +152,8 @@ public Result apply(AggregationNode aggregation, Captures 
captures, Context cont ImmutableList.of(), aggregation.getStep(), aggregation.getHashVariable(), - aggregation.getGroupIdVariable()); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId()); JoinNode rewrittenJoin; if (join.getType() == JoinNode.Type.LEFT) { @@ -378,6 +379,7 @@ private Optional createAggregationOverNull(AggregationNod ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 7318390f3ae0..90f79c77683d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.cost.PartialAggregationStatsEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.matching.Capture; @@ -48,6 +49,8 @@ import static com.facebook.presto.SystemSessionProperties.getPartialAggregationByteReductionThreshold; import static com.facebook.presto.SystemSessionProperties.getPartialAggregationStrategy; import static com.facebook.presto.SystemSessionProperties.isStreamingForPartialAggregationEnabled; +import static com.facebook.presto.SystemSessionProperties.usePartialAggregationHistory; +import static com.facebook.presto.cost.PartialAggregationStatsEstimate.isUnknown; import static com.facebook.presto.operator.aggregation.AggregationUtils.isDecomposable; import static 
com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; @@ -126,8 +129,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context if (!decomposable || partialAggregationStrategy == NEVER || partialAggregationStrategy == AUTOMATIC && - partialAggregationNotUseful(aggregationNode, exchangeNode, context) && - aggregationNode.getGroupingKeys().size() == 1) { + partialAggregationNotUseful(aggregationNode, exchangeNode, context, aggregationNode.getGroupingKeys().size())) { return Result.empty(); } @@ -289,6 +291,7 @@ private PlanNode split(AggregationNode node, Context context) preGroupedSymbols = ImmutableList.copyOf(node.getGroupingSets().getGroupingKeys()); } + Integer aggregationId = Integer.parseInt(context.getIdAllocator().getNextId().getId()); PlanNode partial = new AggregationNode( node.getSourceLocation(), context.getIdAllocator().getNextId(), @@ -300,7 +303,8 @@ private PlanNode split(AggregationNode node, Context context) preGroupedSymbols, PARTIAL, node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + Optional.of(aggregationId)); return new AggregationNode( node.getSourceLocation(), @@ -313,20 +317,31 @@ private PlanNode split(AggregationNode node, Context context) ImmutableList.of(), FINAL, node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + Optional.of(aggregationId)); } - private boolean partialAggregationNotUseful(AggregationNode aggregationNode, ExchangeNode exchangeNode, Context context) + private boolean partialAggregationNotUseful(AggregationNode aggregationNode, ExchangeNode exchangeNode, Context context, int numAggregationKeys) { StatsProvider stats = context.getStatsProvider(); PlanNodeStatsEstimate exchangeStats = stats.getStats(exchangeNode); PlanNodeStatsEstimate aggregationStats = stats.getStats(aggregationNode); - double inputBytes = exchangeStats.getOutputSizeInBytes(exchangeNode); 
- double outputBytes = aggregationStats.getOutputSizeInBytes(aggregationNode); + double inputSize = exchangeStats.getOutputSizeInBytes(exchangeNode); + double outputSize = aggregationStats.getOutputSizeInBytes(aggregationNode); + PartialAggregationStatsEstimate partialAggregationStatsEstimate = aggregationStats.getPartialAggregationStatsEstimate(); + boolean isConfident = exchangeStats.isConfident(); + // keep old behavior of skipping partial aggregation only for single-key aggregations + boolean numberOfKeyCheck = usePartialAggregationHistory(context.getSession()) || numAggregationKeys == 1; + if (!isUnknown(partialAggregationStatsEstimate) && usePartialAggregationHistory(context.getSession())) { + isConfident = aggregationStats.isConfident(); + // use rows instead of bytes when use_partial_aggregation_history flag is on + inputSize = partialAggregationStatsEstimate.getInputRowCount(); + outputSize = partialAggregationStatsEstimate.getOutputRowCount(); + } double byteReductionThreshold = getPartialAggregationByteReductionThreshold(context.getSession()); // calling this function means we are using a cost-based strategy for this optimization - return exchangeStats.isConfident() && outputBytes > inputBytes * byteReductionThreshold; + return numberOfKeyCheck && isConfident && outputSize > inputSize * byteReductionThreshold; } private static boolean isLambda(RowExpression rowExpression) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java index e431ee20dd32..025ecd18b10b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java @@ -171,7 +171,8 @@ private AggregationNode replaceAggregationSource( 
ImmutableList.of(), aggregation.getStep(), aggregation.getHashVariable(), - aggregation.getGroupIdVariable()); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId()); } private PlanNode pushPartialToJoin( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantAggregateDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantAggregateDistinct.java index 2d4de68c2783..206bda115479 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantAggregateDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantAggregateDistinct.java @@ -88,6 +88,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable())); + node.getGroupIdVariable(), + node.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java index e3674ec9a738..52e5a12d57d6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java @@ -231,7 +231,8 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), - aggregationNode.getGroupIdVariable())); + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId())); } private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject) diff --git 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index 2b2261bce301..1f0f47e3e41d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -143,7 +143,8 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable())); + node.getGroupIdVariable(), + node.getAggregationId())); } private static boolean isFunctionNameMatch(RowExpression rowExpression, String expectedName) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index bdb4885a15c3..00ff288819cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -471,7 +471,8 @@ public Result apply(AggregationNode node, Captures captures, Context context) node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + node.getAggregationId()); return Result.ofPlanNode(aggregationNode); } return Result.empty(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 699133cbfe30..24dbb0942e2e 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -105,7 +105,8 @@ public Result apply(AggregationNode parent, Captures captures, Context context) ImmutableList.of(), parent.getStep(), parent.getHashVariable(), - parent.getGroupIdVariable())); + parent.getGroupIdVariable(), + parent.getAggregationId())); } private boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Assignments inputs) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index e59fb59c45d8..2286b6b1b19d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -138,6 +138,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()), // remove DISTINCT flag from function calls aggregation.getAggregations() @@ -149,6 +150,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont emptyList(), aggregation.getStep(), aggregation.getHashVariable(), - aggregation.getGroupIdVariable())); + aggregation.getGroupIdVariable(), + aggregation.getAggregationId())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index e30162442d63..615d4256cda6 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -222,6 +222,7 @@ private PlanNode buildInPredicateEquivalent( ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToLeftEarlyOutJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToLeftEarlyOutJoin.java index 6abb575612c2..76c21eed2b7c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToLeftEarlyOutJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToLeftEarlyOutJoin.java @@ -152,7 +152,8 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), - aggregationNode.getGroupIdVariable()); + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId()); return Result.ofPlanNode(newAggregationNode); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToRightEarlyOutJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToRightEarlyOutJoin.java index f3e74d9bacd1..d1d4cc12a4c0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToRightEarlyOutJoin.java +++ 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformDistinctInnerJoinToRightEarlyOutJoin.java @@ -111,6 +111,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); JoinNode newInnerJoin = new JoinNode( @@ -136,7 +137,8 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), - aggregationNode.getGroupIdVariable()); + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId()); return Result.ofPlanNode(newDistinctNode); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index c58eef9b112f..20531aba77ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -188,6 +188,7 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()), Assignments.of(exists, comparisonExpression(functionResolution, GREATER_THAN, count, new ConstantExpression(0L, BIGINT)))), parent.getCorrelation(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin.java index b5e5525db01a..ef4abc92d84e 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin.java @@ -153,6 +153,7 @@ public Result apply(ApplyNode applyNode, Captures captures, Context context) ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); ImmutableList referencedOutputs = ImmutableList.builder() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index a2e1fe1b6e46..019659096cdf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -378,7 +378,8 @@ else if (matchResult.get(0).get().getColumns().size() < groupingKeys.size() && i preGroupedSymbols, node.getStep(), node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + node.getAggregationId()); return deriveProperties(result, child.getProperties()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 72d5355e6899..07470f46fe58 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -196,7 +196,8 @@ public PlanWithProperties visitAggregation(AggregationNode node, HashComputation node.getPreGroupedVariables(), node.getStep(), hashVariable, - node.getGroupIdVariable()), + node.getGroupIdVariable(), + node.getAggregationId()), 
hashVariable.isPresent() ? ImmutableMap.of(groupByHash.get(), hashVariable.get()) : ImmutableMap.of()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index 88d341208b37..d19dd51f3e8e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -274,6 +274,7 @@ private AggregationNode computeCounts(UnionNode sourceNode, List context) @@ -315,7 +316,8 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + node.getAggregationId()); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 05e8ade5787a..b3db2189fbfb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -238,7 +238,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext varMap = new HashMap<>(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 4e9748206243..c5f346b7b6b3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ 
b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -185,6 +185,7 @@ public Optional visitLimit(LimitNode node, Void context) ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); return Optional.of(new DecorrelationResult( @@ -241,7 +242,8 @@ public Optional visitAggregation(AggregationNode node, Void ImmutableList.of(), decorrelatedAggregation.getStep(), decorrelatedAggregation.getHashVariable(), - decorrelatedAggregation.getGroupIdVariable()); + decorrelatedAggregation.getGroupIdVariable(), + decorrelatedAggregation.getAggregationId()); boolean atMostSingleRow = newAggregation.getGroupingSetCount() == 1 && constantVariables.containsAll(newAggregation.getGroupingKeys()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index dd49a07571bc..d5219e4fce33 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -1641,7 +1641,8 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext createAggregationNode( ImmutableList.of(), scalarAggregation.getStep(), scalarAggregation.getHashVariable(), + Optional.empty(), Optional.empty())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java index 081f4a30d736..c5f6e02d0c59 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SetFlatteningOptimizer.java @@ -155,7 +155,8 @@ public PlanNode 
visitAggregation(AggregationNode node, RewriteContext c ImmutableList.of(), node.getStep(), node.getHashVariable(), - node.getGroupIdVariable()); + node.getGroupIdVariable(), + node.getAggregationId()); } private static boolean isDistinctOperator(AggregationNode node) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 2961d28d0ece..099cd97a1f1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -184,7 +184,8 @@ private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId ne mapAndDistinctVariable(node.getPreGroupedVariables()), node.getStep(), node.getHashVariable().map(this::map), - node.getGroupIdVariable().map(this::map)); + node.getGroupIdVariable().map(this::map), + node.getAggregationId()); } private Aggregation map(Aggregation aggregation) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index fb3cbd2b574f..d402cf81cb42 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -202,6 +202,7 @@ countNonNullValue, new Aggregation( ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); PlanNode lateralJoinNode = new LateralJoinNode( diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java 
b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 68336bc5238d..680229bbb324 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -811,6 +811,7 @@ private AggregationNode aggregation(String id, PlanNode source) ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), + Optional.empty(), Optional.empty()); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java index 263f59ad29d9..63cc56f3cbff 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFragmentStatsProvider.java @@ -32,8 +32,8 @@ public void testFragmentStatsProvider() QueryId queryId2 = new QueryId("queryid2"); PlanFragmentId planFragmentId1 = new PlanFragmentId(1); PlanFragmentId planFragmentId2 = new PlanFragmentId(2); - PlanNodeStatsEstimate planNodeStatsEstimate1 = new PlanNodeStatsEstimate(NaN, 10, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); - PlanNodeStatsEstimate planNodeStatsEstimate2 = new PlanNodeStatsEstimate(NaN, 100, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown()); + PlanNodeStatsEstimate planNodeStatsEstimate1 = new PlanNodeStatsEstimate(NaN, 10, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown()); + PlanNodeStatsEstimate planNodeStatsEstimate2 = new PlanNodeStatsEstimate(NaN, 100, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown()); assertEquals(fragmentStatsProvider.getStats(queryId1, planFragmentId1), 
PlanNodeStatsEstimate.unknown()); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java index 05cee3a47d32..a64923a883c4 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoricalPlanStatistics.java @@ -16,6 +16,7 @@ import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; import com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.google.common.collect.ImmutableList; @@ -84,7 +85,7 @@ public void testMaxStatistics() private PlanStatistics stats(double rows, double size) { - return new PlanStatistics(Estimate.of(rows), Estimate.of(size), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); + return new PlanStatistics(Estimate.of(rows), Estimate.of(size), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()); } private static HistoricalPlanStatistics updatePlanStatistics( diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java index 9b650d1427da..1f595ea74ce9 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistoryBasedStatsProvider.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.statistics.HistoricalPlanStatisticsEntry; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import 
com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.Optimizer; @@ -124,8 +125,8 @@ public Map getStats(List planBuilder.remoteSource(ImmutableList.of(new PlanFragmentId(1), new PlanFragmentId(2)))) .check(check -> check.totalSize(2000) .outputRowsCountUnknown()); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java b/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java index 642aff3455fc..4f564613bf85 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestDriver.java @@ -120,6 +120,7 @@ public class TestDriver ImmutableList.of(), AggregationNode.Step.PARTIAL, Optional.empty(), + Optional.empty(), Optional.empty()), new PartitioningScheme(Partitioning.create(FIXED_HASH_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()), testSessionBuilder().setSystemProperty(FRAGMENT_RESULT_CACHING_ENABLED, "true").build(), diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctionUtils.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctionUtils.java index 84b11445309a..0c77a04d046e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctionUtils.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctionUtils.java @@ -27,7 +27,7 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " \"identifier\" : \"[a, b, a, b]\",\n" + " \"details\" : \"b := b_1 (1:41)\\n\",\n" + " \"children\" : [ {\n" + - " \"id\" : \"241\",\n" + + " \"id\" : \"253\",\n" + " \"name\" : \"RemoteSource\",\n" + " 
\"identifier\" : \"[1]\",\n" + " \"details\" : \"\",\n" + @@ -65,22 +65,28 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " } ]\n" + " }\n" + " },\n" + " \"1\" : {\n" + " \"plan\" : {\n" + - " \"id\" : \"218\",\n" + + " \"id\" : \"230\",\n" + " \"name\" : \"InnerJoin\",\n" + " \"identifier\" : \"[(\\\"a\\\" = \\\"a_0\\\")][$hashvalue, $hashvalue_21]\",\n" + " \"details\" : \"Distribution: PARTITIONED\\n\",\n" + " \"children\" : [ {\n" + - " \"id\" : \"239\",\n" + + " \"id\" : \"251\",\n" + " \"name\" : \"RemoteSource\",\n" + " \"identifier\" : \"[2]\",\n" + " \"details\" : \"\",\n" + @@ -88,12 +94,12 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " \"remoteSources\" : [ \"2\" ],\n" + " \"estimates\" : [ ]\n" + " }, {\n" + - " \"id\" : \"272\",\n" + + " \"id\" : \"284\",\n" + " \"name\" : \"LocalExchange\",\n" + " \"identifier\" : \"[HASH][$hashvalue_21] (a_0)\",\n" + " \"details\" : \"\",\n" + " \"children\" : [ {\n" + - " \"id\" : \"240\",\n" + + " \"id\" : \"252\",\n" + " \"name\" : \"RemoteSource\",\n" + " \"identifier\" : \"[3]\",\n" + " \"details\" : \"\",\n" + @@ -131,10 +137,16 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " 
\"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " } ]\n" + " } ],\n" + @@ -168,20 +180,26 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " } ]\n" + " }\n" + " },\n" + " \"2\" : {\n" + " \"plan\" : {\n" + - " \"id\" : \"301\",\n" + + " \"id\" : \"313\",\n" + " \"name\" : \"ScanProject\",\n" + " \"identifier\" : \"[table = TableHandle {connectorId='hive', connectorHandle='HiveTableHandle{schemaName=tpch, tableName=r, analyzePartitionValues=Optional.empty}', layout='Optional[tpch.r{}]'}, projectLocality = LOCAL]\",\n" + - " \"details\" : \"$hashvalue_20 := combine_hash(BIGINT'0', COALESCE($operator$hash_code(a), BIGINT'0')) (1:55)\\nLAYOUT: tpch.r{}\\na := a:int:0:REGULAR (1:55)\\nb := b:int:1:REGULAR (1:55)\\n\",\n" + + " \"details\" : \"$hashvalue_20 := combine_hash(BIGINT'0', COALESCE($operator$hash_code(a), BIGINT'0')) (1:55)\\nLAYOUT: tpch.r{}\\nb := b:int:1:REGULAR (1:55)\\na := a:int:0:REGULAR (1:55)\\n\",\n" + " \"children\" : [ ],\n" + " \"remoteSources\" : [ ],\n" + " \"estimates\" : [ {\n" + @@ -206,10 +224,16 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + 
" \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " }, {\n" + " \"outputRowCount\" : 0.0,\n" + @@ -240,17 +264,23 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " } ]\n" + " }\n" + " },\n" + " \"3\" : {\n" + " \"plan\" : {\n" + - " \"id\" : \"302\",\n" + + " \"id\" : \"314\",\n" + " \"name\" : \"ScanProject\",\n" + " \"identifier\" : \"[table = TableHandle {connectorId='hive', connectorHandle='HiveTableHandle{schemaName=tpch, tableName=s, analyzePartitionValues=Optional.empty}', layout='Optional[tpch.s{}]'}, projectLocality = LOCAL]\",\n" + " \"details\" : \"$hashvalue_23 := combine_hash(BIGINT'0', COALESCE($operator$hash_code(a_0), BIGINT'0')) (1:57)\\nLAYOUT: tpch.s{}\\nb_1 := b:int:1:REGULAR (1:57)\\na_0 := a:int:0:REGULAR (1:57)\\n\",\n" + @@ -278,10 +308,16 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" 
+ " }\n" + " }, {\n" + " \"outputRowCount\" : 0.0,\n" + @@ -312,10 +348,16 @@ private TestJsonPrestoQueryPlanFunctionUtils() {} " },\n" + " \"joinNodeStatsEstimate\" : {\n" + " \"nullJoinBuildKeyCount\" : \"NaN\",\n" + - " \"joinBuildKeyCount\" : \"NaN\"\n" + + " \"joinBuildKeyCount\" : \"NaN\",\n" + + " \"nullJoinProbeKeyCount\" : \"NaN\",\n" + + " \"joinProbeKeyCount\" : \"NaN\"\n" + " },\n" + " \"tableWriterNodeStatsEstimate\" : {\n" + " \"taskCountIfScaledWriter\" : \"NaN\"\n" + + " },\n" + + " \"partialAggregationStatsEstimate\" : {\n" + + " \"inputBytes\" : \"NaN\",\n" + + " \"outputBytes\" : \"NaN\"\n" + " }\n" + " } ]\n" + " }\n" + diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctions.java index 0875a62ed4e8..c7ccc29758f6 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/queryplan/TestJsonPrestoQueryPlanFunctions.java @@ -31,7 +31,7 @@ public void testJsonPlanIds() assertFunction("json_presto_query_plan_ids(json '" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "')", - new ArrayType(VARCHAR), ImmutableList.of("301", "302", "8", "239", "218", "272", "240", "241")); + new ArrayType(VARCHAR), ImmutableList.of("253", "313", "314", "8", "251", "284", "230", "252")); } @Test @@ -40,11 +40,11 @@ public void testJsonPlanNodeChildren() assertFunction("json_presto_query_plan_node_children(null, null)", new ArrayType(VARCHAR), null); assertFunction("json_presto_query_plan_node_children(null, '1')", new ArrayType(VARCHAR), null); - assertFunction("json_presto_query_plan_node_children(json '" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "', '301')", + assertFunction("json_presto_query_plan_node_children(json 
'" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "', '314')", new ArrayType(VARCHAR), ImmutableList.of()); - assertFunction("json_presto_query_plan_node_children(json '" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "', '218')", - new ArrayType(VARCHAR), ImmutableList.of("239", "272")); + assertFunction("json_presto_query_plan_node_children(json '" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "', '230')", + new ArrayType(VARCHAR), ImmutableList.of("251", "284")); assertFunction("json_presto_query_plan_node_children(json '" + TestJsonPrestoQueryPlanFunctionUtils.joinPlan.replaceAll("'", "''") + "', 'nonkey')", new ArrayType(VARCHAR), null); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index b73297778d20..478461ff9383 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -79,6 +79,7 @@ public void testDefaults() .setMaxReorderedJoins(9) .setUseHistoryBasedPlanStatistics(false) .setTrackHistoryBasedPlanStatistics(false) + .setUsePartialAggregationHistory(false) .setUsePerfectlyConsistentHistories(false) .setHistoryCanonicalPlanNodeLimit(1000) .setHistoryBasedOptimizerTimeout(new Duration(10, SECONDS)) @@ -297,6 +298,7 @@ public void testExplicitPropertyMappings() .put("optimizer.max-reordered-joins", "5") .put("optimizer.use-history-based-plan-statistics", "true") .put("optimizer.track-history-based-plan-statistics", "true") + .put("optimizer.use-partial-aggregation-history", "true") .put("optimizer.use-perfectly-consistent-histories", "true") .put("optimizer.history-canonical-plan-node-limit", "2") .put("optimizer.history-based-optimizer-timeout", "1s") @@ -482,6 +484,7 @@ public void testExplicitPropertyMappings() 
.setMaxReorderedJoins(5) .setUseHistoryBasedPlanStatistics(true) .setTrackHistoryBasedPlanStatistics(true) + .setUsePartialAggregationHistory(true) .setUsePerfectlyConsistentHistories(true) .setHistoryCanonicalPlanNodeLimit(2) .setHistoryBasedOptimizerTimeout(new Duration(1, SECONDS)) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 4e0825819024..6210aaf82d21 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -167,6 +167,7 @@ DV, count(metadata.getFunctionAndTypeManager())), ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), + Optional.empty(), Optional.empty()); RowExpression effectivePredicate = effectivePredicateExtractor.extract(node); @@ -192,6 +193,7 @@ public void testGroupByEmpty() ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), + Optional.empty(), Optional.empty()); RowExpression effectivePredicate = effectivePredicateExtractor.extract(node); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 5cba4c14e2ab..548805cc3a42 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -212,6 +212,7 @@ public void testValidAggregation() ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -239,6 +240,7 @@ public void testValidIntermediateAggregation() ImmutableList.of(), INTERMEDIATE, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -266,6 +268,7 @@ public void 
testValidPartialAggregation() ImmutableList.of(), PARTIAL, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -323,6 +326,7 @@ public void testInvalidIntermediateAggregationReturnType() ImmutableList.of(), INTERMEDIATE, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -350,6 +354,7 @@ public void testInvalidPartialAggregationReturnType() ImmutableList.of(), PARTIAL, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -378,6 +383,7 @@ public void testInvalidAggregationFunctionCall() ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); @@ -406,6 +412,7 @@ public void testInvalidAggregationFunctionSignature() ImmutableList.of(), SINGLE, Optional.empty(), + Optional.empty(), Optional.empty()); assertTypesValid(node); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java index ea0d1de9beb9..18d973cdf60f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushPartialAggregationThroughExchange.java @@ -13,16 +13,20 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.cost.PartialAggregationStatsEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.VariableStatsEstimate; +import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; 
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import static com.facebook.presto.SystemSessionProperties.PARTIAL_AGGREGATION_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.USE_PARTIAL_AGGREGATION_HISTORY; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; @@ -107,6 +111,71 @@ public void testNoPartialAggregationWhenReductionBelowThreshold() .doesNotFire(); } + @Test + public void testNoPartialAggregationWhenReductionBelowThresholdUsingPartialAggregationStats() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true") + .on(p -> constructAggregation(p)) + .overrideStats("aggregation", PlanNodeStatsEstimate.builder() + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(true) + .setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(1000, 800, 10, 10)) + .build()) + .doesNotFire(); + } + + @Test + public void testNoPartialAggregationWhenReductionAboveThresholdUsingPartialAggregationStats() + { + // when use_partial_aggregation_history=true, we use row count reduction (instead of bytes) to decide if partial aggregation is useful + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true") + .on(p -> constructAggregation(p)) + .overrideStats("aggregation", PlanNodeStatsEstimate.builder() + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(true) + 
.setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(1000, 300, 10, 10)) + .build()) + .doesNotFire(); + } + + @Test + public void testNoPartialAggregationWhenRowReductionBelowThreshold() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true") + .on(p -> constructAggregation(p)) + .overrideStats("aggregation", PlanNodeStatsEstimate.builder() + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(true) + .setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(0, 300, 10, 8)) + .build()) + .doesNotFire(); + } + + @Test + public void testPartialAggregationWhenRowReductionAboveThreshold() + { + tester().assertThat(new PushPartialAggregationThroughExchange(getFunctionManager())) + .setSystemProperty(PARTIAL_AGGREGATION_STRATEGY, "AUTOMATIC") + .setSystemProperty(USE_PARTIAL_AGGREGATION_HISTORY, "true") + .on(p -> constructAggregation(p)) + .overrideStats("aggregation", PlanNodeStatsEstimate.builder() + .addVariableStatistics(variable("b", DOUBLE), new VariableStatsEstimate(0, 100, 0, 8, 800)) + .setConfident(true) + .setPartialAggregationStatsEstimate(new PartialAggregationStatsEstimate(0, 300, 10, 1)) + .build()) + .matches(aggregation(ImmutableMap.of("sum", functionCall("sum", ImmutableList.of("sum0"))), + aggregation( + ImmutableMap.of("sum0", functionCall("sum", ImmutableList.of("a"))), + exchange( + values("a", "b"))))); + } + @Test public void testPartialAggregationEnabledWhenNotConfident() { @@ -137,4 +206,20 @@ public void testPartialAggregationEnabledWhenNotConfident() PARTIAL, values("a", "b"))))); } + + private static AggregationNode constructAggregation(PlanBuilder p) + { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression b = p.variable("b", DOUBLE); + return 
p.aggregation(ab -> ab + .source( + p.exchange(e -> e + .addSource(p.values(new PlanNodeId("values"), a, b)) + .addInputsSet(a, b) + .singleDistributionPartitioningScheme( + ImmutableList.of(a, b)))) + .addAggregation(p.variable("sum", DOUBLE), p.rowExpression("sum(a)")) + .singleGroupingSet(b) + .setPlanNodeId(new PlanNodeId("aggregation"))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 28e8cf750667..c44e5a2b2e2c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -358,6 +358,7 @@ public class AggregationBuilder { private final TypeProvider types; private PlanNode source; + private PlanNodeId planNodeId; // Preserve order when creating assignments, so it's consistent when printed/iterated. Some // optimizations create variable names by iterating over it, and this will make plan more consistent // in future runs. @@ -380,6 +381,12 @@ public AggregationBuilder source(PlanNode source) return this; } + public AggregationBuilder setPlanNodeId(PlanNodeId planNodeId) + { + this.planNodeId = planNodeId; + return this; + } + public AggregationBuilder addAggregation(VariableReferenceExpression output, RowExpression expression) { return addAggregation(output, expression, false); @@ -463,14 +470,15 @@ protected AggregationNode build() checkState(groupingSets != null, "No grouping sets defined; use globalGrouping/groupingKeys method"); return new AggregationNode( source.getSourceLocation(), - idAllocator.getNextId(), + planNodeId == null ? 
idAllocator.getNextId() : planNodeId, source, assignments, groupingSets, preGroupedVariables, step, hashVariable, - groupIdVariable); + groupIdVariable, + Optional.empty()); } } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java index b37ee289f1bd..97476ec87714 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestIterativePlanFragmenter.java @@ -328,6 +328,7 @@ private AggregationNode aggregation(String id, PlanNode source) ImmutableList.of(), AggregationNode.Step.FINAL, Optional.empty(), + Optional.empty(), Optional.empty()); } diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java index f44759eced9d..505e5a3f9fb8 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/TestPrestoSparkStatsCalculator.java @@ -19,6 +19,7 @@ import com.facebook.presto.cost.HistoryBasedOptimizationConfig; import com.facebook.presto.cost.HistoryBasedPlanStatisticsCalculator; import com.facebook.presto.cost.JoinNodeStatsEstimate; +import com.facebook.presto.cost.PartialAggregationStatsEstimate; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsCalculatorTester; import com.facebook.presto.cost.TableWriterNodeStatsEstimate; @@ -33,6 +34,7 @@ import com.facebook.presto.spi.statistics.HistoricalPlanStatisticsEntry; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.statistics.JoinNodeStatistics; +import 
com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; @@ -110,7 +112,7 @@ public void resetCaches() @Test public void testUsesHboStatsWhenMatchRuntime() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); PlanNode statsEquivalentRemoteSource = planBuilder .registerVariable(planBuilder.variable("c1")) @@ -126,7 +128,7 @@ public void testUsesHboStatsWhenMatchRuntime() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)), statsEquivalentRemoteSource)) @@ -137,7 +139,7 @@ public void testUsesHboStatsWhenMatchRuntime() @Test public void testUsesRuntimeStatsWhenNoHboStats() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new 
PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown())); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) .check(check -> check.totalSize(1000) .outputRowsCountUnknown()); @@ -160,7 +162,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() StatsCalculatorTester tester = new StatsCalculatorTester( localQueryRunner, prestoSparkStatsCalculator); - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), localQueryRunner.getMetadata()); PlanNode statsEquivalentRemoteSource = planBuilder @@ -175,7 +177,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) @@ -187,7 +189,7 @@ public void testUsesRuntimeStatsWhenHboDisabled() @Test public void testUsesRuntimeStatsWhenDiffersFromHbo() { - fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), 
TableWriterNodeStatsEstimate.unknown())); + fragmentStatsProvider.putStats(TEST_QUERY_ID, new PlanFragmentId(1), new PlanNodeStatsEstimate(NaN, 1000, true, ImmutableMap.of(), JoinNodeStatsEstimate.unknown(), TableWriterNodeStatsEstimate.unknown(), PartialAggregationStatsEstimate.unknown())); PlanBuilder planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); PlanNode statsEquivalentRemoteSource = planBuilder @@ -202,7 +204,7 @@ public void testUsesRuntimeStatsWhenDiffersFromHbo() new HistoricalPlanStatistics( ImmutableList.of( new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(10), Estimate.of(100), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + new PlanStatistics(Estimate.of(10), Estimate.of(100), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), ImmutableList.of()))))); tester.assertStatsFor(pb -> pb.remoteSource(ImmutableList.of(new PlanFragmentId(1)))) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/plan/AggregationNode.java b/presto-spi/src/main/java/com/facebook/presto/spi/plan/AggregationNode.java index ab97034924a7..81c1505c4a7c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/plan/AggregationNode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/plan/AggregationNode.java @@ -33,6 +33,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; import static java.util.Collections.emptyList; import static java.util.Collections.emptySet; @@ -52,6 +53,7 @@ public final class AggregationNode private final Step step; private final Optional hashVariable; private final Optional groupIdVariable; + private final Optional aggregationId; private final List outputs; @JsonCreator @@ -64,9 +66,10 @@ public AggregationNode( @JsonProperty("preGroupedVariables") List preGroupedVariables, 
@JsonProperty("step") Step step, @JsonProperty("hashVariable") Optional hashVariable, - @JsonProperty("groupIdVariable") Optional groupIdVariable) + @JsonProperty("groupIdVariable") Optional groupIdVariable, + @JsonProperty("aggregationId") Optional aggregationId) { - this(sourceLocation, id, Optional.empty(), source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable); + this(sourceLocation, id, Optional.empty(), source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable, aggregationId); } public AggregationNode( @@ -79,7 +82,8 @@ public AggregationNode( List preGroupedVariables, Step step, Optional hashVariable, - Optional groupIdVariable) + Optional groupIdVariable, + Optional aggregationId) { super(sourceLocation, id, statsEquivalentPlanNode); @@ -91,6 +95,7 @@ public AggregationNode( this.groupingSets = groupingSets; this.groupIdVariable = requireNonNull(groupIdVariable); + this.aggregationId = requireNonNull(aggregationId); boolean noOrderBy = aggregations.values().stream() .map(Aggregation::getOrderBy) @@ -220,6 +225,12 @@ public Optional getGroupIdVariable() return groupIdVariable; } + @JsonProperty + public Optional getAggregationId() + { + return aggregationId; + } + public boolean hasOrderings() { return aggregations.values().stream() @@ -236,14 +247,14 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new AggregationNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable); + return new AggregationNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable, aggregationId); } @Override public PlanNode replaceChildren(List newChildren) { checkArgument(newChildren.size() == 1, "Unexpected number of 
elements in list newChildren"); - return new AggregationNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable); + return new AggregationNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), newChildren.get(0), aggregations, groupingSets, preGroupedVariables, step, hashVariable, groupIdVariable, aggregationId); } public boolean isStreamable() diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PartialAggregationStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PartialAggregationStatistics.java new file mode 100644 index 000000000000..383b55e8fff1 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PartialAggregationStatistics.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.statistics; + +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public class PartialAggregationStatistics +{ + private static final PartialAggregationStatistics EMPTY = new PartialAggregationStatistics(Estimate.unknown(), Estimate.unknown(), Estimate.unknown(), Estimate.unknown()); + // Number of input bytes + private final Estimate partialAggregationInputBytes; + // Number of output bytes + private final Estimate partialAggregationOutputBytes; + + private final Estimate partialAggregationInputRows; + + private final Estimate partialAggregationOutputRows; + + @JsonCreator + @ThriftConstructor + public PartialAggregationStatistics( + @JsonProperty("partialAggregationInputBytes") Estimate partialAggregationInputBytes, + @JsonProperty("partialAggregationOutputBytes") Estimate partialAggregationOutputBytes, + @JsonProperty("partialAggregationInputRows") Estimate partialAggregationInputRows, + @JsonProperty("partialAggregationOutputRows") Estimate partialAggregationOutputRows) + { + this.partialAggregationInputBytes = requireNonNull(partialAggregationInputBytes, "partialAggregationInputBytes is null"); + this.partialAggregationOutputBytes = requireNonNull(partialAggregationOutputBytes, "partialAggregationOutputBytes is null"); + this.partialAggregationInputRows = requireNonNull(partialAggregationInputRows, "partialAggregationInputRows is null"); + this.partialAggregationOutputRows = requireNonNull(partialAggregationOutputRows, "partialAggregationOutputRows is null"); + } + + public static PartialAggregationStatistics empty() + { + return EMPTY; + } + + public boolean isEmpty() + { + return this.equals(empty()); + } + + 
@JsonProperty + @ThriftField(1) + public Estimate getPartialAggregationInputBytes() + { + return partialAggregationInputBytes; + } + + @JsonProperty + @ThriftField(2) + public Estimate getPartialAggregationOutputBytes() + { + return partialAggregationOutputBytes; + } + + @JsonProperty + @ThriftField(3) + public Estimate getPartialAggregationInputRows() + { + return partialAggregationInputRows; + } + + @JsonProperty + @ThriftField(4) + public Estimate getPartialAggregationOutputRows() + { + return partialAggregationOutputRows; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialAggregationStatistics that = (PartialAggregationStatistics) o; + return Objects.equals(partialAggregationInputBytes, that.partialAggregationInputBytes) && + Objects.equals(partialAggregationOutputBytes, that.partialAggregationOutputBytes) && + Objects.equals(partialAggregationInputRows, that.partialAggregationInputRows) && + Objects.equals(partialAggregationOutputRows, that.partialAggregationOutputRows); + } + + @Override + public int hashCode() + { + return Objects.hash(partialAggregationInputBytes, partialAggregationOutputBytes, partialAggregationInputRows, partialAggregationOutputRows); + } + + @Override + public String toString() + { + return "PartialAggregationStatistics{" + + "partialAggregationInputBytes=" + partialAggregationInputBytes + + ", partialAggregationOutputBytes=" + partialAggregationOutputBytes + + ", partialAggregationInputRows=" + partialAggregationInputRows + + ", partialAggregationOutputRows=" + partialAggregationOutputRows + + '}'; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java index a5e365fedcbf..cf7934d591f1 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java +++ 
b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/PlanStatistics.java @@ -27,7 +27,7 @@ @ThriftStruct public class PlanStatistics { - private static final PlanStatistics EMPTY = new PlanStatistics(Estimate.unknown(), Estimate.unknown(), 0, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()); + private static final PlanStatistics EMPTY = new PlanStatistics(Estimate.unknown(), Estimate.unknown(), 0, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()); private final Estimate rowCount; private final Estimate outputSize; @@ -37,6 +37,8 @@ public class PlanStatistics private final JoinNodeStatistics joinNodeStatistics; // TableWriter node specific statistics private final TableWriterNodeStatistics tableWriterNodeStatistics; + // Aggregation node specific statistics + private final PartialAggregationStatistics partialAggregationStatistics; public static PlanStatistics empty() { @@ -49,7 +51,8 @@ public PlanStatistics(@JsonProperty("rowCount") Estimate rowCount, @JsonProperty("outputSize") Estimate outputSize, @JsonProperty("confidence") double confidence, @JsonProperty("joinNodeStatistics") JoinNodeStatistics joinNodeStatistics, - @JsonProperty("tableWriterNodeStatistics") TableWriterNodeStatistics tableWriterNodeStatistics) + @JsonProperty("tableWriterNodeStatistics") TableWriterNodeStatistics tableWriterNodeStatistics, + @JsonProperty("partialAggregationStatistics") PartialAggregationStatistics partialAggregationStatistics) { this.rowCount = requireNonNull(rowCount, "rowCount is null"); this.outputSize = requireNonNull(outputSize, "outputSize is null"); @@ -57,6 +60,7 @@ public PlanStatistics(@JsonProperty("rowCount") Estimate rowCount, this.confidence = confidence; this.joinNodeStatistics = requireNonNull(joinNodeStatistics == null ? JoinNodeStatistics.empty() : joinNodeStatistics, "joinNodeStatistics is null"); this.tableWriterNodeStatistics = requireNonNull(tableWriterNodeStatistics == null ? 
TableWriterNodeStatistics.empty() : tableWriterNodeStatistics, "tableWriterNodeStatistics is null"); + this.partialAggregationStatistics = requireNonNull(partialAggregationStatistics == null ? PartialAggregationStatistics.empty() : partialAggregationStatistics, "partialAggregationStatistics is null"); } @JsonProperty @@ -94,7 +98,14 @@ public TableWriterNodeStatistics getTableWriterNodeStatistics() return tableWriterNodeStatistics; } - // Next ThriftField value 8 + @JsonProperty + @ThriftField(value = 8, requiredness = OPTIONAL) + public PartialAggregationStatistics getPartialAggregationStatistics() + { + return partialAggregationStatistics; + } + + // Next ThriftField value 9 public PlanStatistics update(PlanStatistics planStatistics) { @@ -102,7 +113,18 @@ public PlanStatistics update(PlanStatistics planStatistics) planStatistics.getOutputSize(), planStatistics.getConfidence(), planStatistics.getJoinNodeStatistics().isEmpty() ? getJoinNodeStatistics() : planStatistics.getJoinNodeStatistics(), - planStatistics.getTableWriterNodeStatistics().isEmpty() ? getTableWriterNodeStatistics() : planStatistics.getTableWriterNodeStatistics()); + planStatistics.getTableWriterNodeStatistics().isEmpty() ? getTableWriterNodeStatistics() : planStatistics.getTableWriterNodeStatistics(), + planStatistics.getPartialAggregationStatistics().isEmpty() ? 
getPartialAggregationStatistics() : planStatistics.getPartialAggregationStatistics()); + } + + public PlanStatistics updateAggregationStatistics(PartialAggregationStatistics partialAggregationStatistics) + { + return new PlanStatistics(getRowCount(), + getOutputSize(), + getConfidence(), + getJoinNodeStatistics(), + getTableWriterNodeStatistics(), + partialAggregationStatistics); } private static void checkArgument(boolean condition, String message) @@ -123,13 +145,14 @@ public boolean equals(Object o) } PlanStatistics that = (PlanStatistics) o; return Double.compare(that.confidence, confidence) == 0 && Objects.equals(rowCount, that.rowCount) && Objects.equals(outputSize, that.outputSize) - && Objects.equals(joinNodeStatistics, that.joinNodeStatistics) && Objects.equals(tableWriterNodeStatistics, that.tableWriterNodeStatistics); + && Objects.equals(joinNodeStatistics, that.joinNodeStatistics) && Objects.equals(tableWriterNodeStatistics, that.tableWriterNodeStatistics) + && Objects.equals(partialAggregationStatistics, that.partialAggregationStatistics); } @Override public int hashCode() { - return Objects.hash(rowCount, outputSize, confidence, joinNodeStatistics, tableWriterNodeStatistics); + return Objects.hash(rowCount, outputSize, confidence, joinNodeStatistics, tableWriterNodeStatistics, partialAggregationStatistics); } @Override @@ -141,6 +164,7 @@ public String toString() ", confidence=" + confidence + ", joinNodeStatistics=" + joinNodeStatistics + ", tableWriterNodeStatistics=" + tableWriterNodeStatistics + + ", partialAggregationStatistics=" + partialAggregationStatistics + '}'; } } diff --git a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java index def878e92795..8d60f723ff38 100644 --- a/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java +++ 
b/redis-hbo-provider/src/test/java/com/facebook/presto/statistic/TestHistoricalStatisticsSerde.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.statistics.HistoricalPlanStatistics; import com.facebook.presto.spi.statistics.HistoricalPlanStatisticsEntry; import com.facebook.presto.spi.statistics.JoinNodeStatistics; +import com.facebook.presto.spi.statistics.PartialAggregationStatistics; import com.facebook.presto.spi.statistics.PlanStatistics; import com.facebook.presto.spi.statistics.TableWriterNodeStatistics; import com.google.common.collect.ImmutableList; @@ -37,8 +38,8 @@ public class TestHistoricalStatisticsSerde public void testSimpleHistoricalStatisticsEncoderDecoder() { HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(ImmutableList.of(new HistoricalPlanStatisticsEntry( - new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), - ImmutableList.of(new PlanStatistics(Estimate.of(15000), Estimate.unknown(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()))))); + new PlanStatistics(Estimate.of(100), Estimate.of(1000), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), + ImmutableList.of(new PlanStatistics(Estimate.of(15000), Estimate.unknown(), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()))))); HistoricalStatisticsSerde historicalStatisticsEncoderDecoder = new HistoricalStatisticsSerde(); // Test PlanHash @@ -55,8 +56,8 @@ public void testHistoricalPlanStatisticsEntryList() { List historicalPlanStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), - ImmutableList.of(new PlanStatistics(Estimate.of(100), Estimate.of(i), 0, 
JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty())))); + historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), + ImmutableList.of(new PlanStatistics(Estimate.of(100), Estimate.of(i), 0, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty())))); } HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(historicalPlanStatisticsEntryList); HistoricalStatisticsSerde historicalStatisticsEncoderDecoder = new HistoricalStatisticsSerde(); @@ -82,11 +83,11 @@ public void testPlanStatisticsList() { List planStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - planStatisticsEntryList.add(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty())); + planStatisticsEntryList.add(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty())); } List historicalPlanStatisticsEntryList = new ArrayList<>(); for (int i = 0; i < 50; i++) { - historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty()), + historicalPlanStatisticsEntryList.add(new HistoricalPlanStatisticsEntry(new PlanStatistics(Estimate.of(i * 5), Estimate.of(i * 5), 1, JoinNodeStatistics.empty(), TableWriterNodeStatistics.empty(), PartialAggregationStatistics.empty()), planStatisticsEntryList)); } HistoricalPlanStatistics samplePlanStatistics = new HistoricalPlanStatistics(historicalPlanStatisticsEntryList);