diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/BasicOperatorStats.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/BasicOperatorStats.java index ea8bee223d96..822b224fc17a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/BasicOperatorStats.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/BasicOperatorStats.java @@ -15,6 +15,8 @@ import io.trino.spi.metrics.Metrics; +import java.util.List; + import static java.util.Objects.requireNonNull; class BasicOperatorStats @@ -73,4 +75,21 @@ public static BasicOperatorStats merge(BasicOperatorStats first, BasicOperatorSt first.metrics.mergeWith(second.metrics), first.connectorMetrics.mergeWith(second.connectorMetrics)); } + + public static BasicOperatorStats merge(List operatorStats) + { + long totalDrivers = 0; + long inputPositions = 0; + double sumSquaredInputPositions = 0; + Metrics.Accumulator metricsAccumulator = Metrics.accumulator(); + Metrics.Accumulator connectorMetricsAccumulator = Metrics.accumulator(); + for (BasicOperatorStats stats : operatorStats) { + totalDrivers += stats.totalDrivers; + inputPositions += stats.inputPositions; + sumSquaredInputPositions += stats.sumSquaredInputPositions; + metricsAccumulator.add(stats.metrics); + connectorMetricsAccumulator.add(stats.connectorMetrics); + } + return new BasicOperatorStats(totalDrivers, inputPositions, sumSquaredInputPositions, metricsAccumulator.get(), connectorMetricsAccumulator.get()); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStats.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStats.java index d83f680a7c4d..1c2e05ab7436 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStats.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStats.java @@ -13,12 +13,15 @@ */ package io.trino.sql.planner.planprinter; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.spi.Mergeable; import io.trino.sql.planner.plan.PlanNodeId; +import java.util.List; import java.util.Map; import java.util.Set; @@ -30,6 +33,7 @@ import static java.lang.Math.sqrt; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; public class PlanNodeStats implements Mergeable @@ -194,4 +198,59 @@ public PlanNodeStats mergeWith(PlanNodeStats other) succinctBytes(this.planNodeSpilledDataSize.toBytes() + other.planNodeSpilledDataSize.toBytes()), operatorStats); } + + @Override + public PlanNodeStats mergeWith(List others) + { + long planNodeInputPositions = this.planNodeInputPositions; + long planNodeOutputPositions = this.planNodeOutputPositions; + long planNodeInputDataSizeBytes = planNodeInputDataSize.toBytes(); + long planNodeOutputDataSizeBytes = planNodeOutputDataSize.toBytes(); + long planNodePhysicalInputDataSizeBytes = planNodePhysicalInputDataSize.toBytes(); + long planNodeSpilledDataSizeBytes = planNodeSpilledDataSize.toBytes(); + long planNodeScheduledTimeMillis = planNodeScheduledTime.toMillis(); + long planNodeCpuTimeMillis = planNodeCpuTime.toMillis(); + long planNodeBlockedTimeMillis = planNodeBlockedTime.toMillis(); + double planNodePhysicalInputReadNanos = planNodePhysicalInputReadTime.getValue(NANOSECONDS); + ListMultimap groupedOperatorStats = ArrayListMultimap.create(); + for (Map.Entry entry : this.operatorStats.entrySet()) { + groupedOperatorStats.put(entry.getKey(), entry.getValue()); + } + + for (PlanNodeStats other : others) { + checkArgument(planNodeId.equals(other.getPlanNodeId()), "planNodeIds do not match. %s != %s", planNodeId, other.getPlanNodeId()); + planNodeInputPositions += other.planNodeInputPositions; + planNodeOutputPositions += other.planNodeOutputPositions; + planNodeScheduledTimeMillis += other.planNodeScheduledTime.toMillis(); + planNodeCpuTimeMillis += other.planNodeCpuTime.toMillis(); + planNodeBlockedTimeMillis += other.planNodeBlockedTime.toMillis(); + planNodePhysicalInputReadNanos += other.planNodePhysicalInputReadTime.getValue(NANOSECONDS); + planNodePhysicalInputDataSizeBytes += other.planNodePhysicalInputDataSize.toBytes(); + planNodeInputDataSizeBytes += other.planNodeInputDataSize.toBytes(); + planNodeOutputDataSizeBytes += other.planNodeOutputDataSize.toBytes(); + planNodeSpilledDataSizeBytes += other.planNodeSpilledDataSize.toBytes(); + for (Map.Entry entry : other.operatorStats.entrySet()) { + groupedOperatorStats.put(entry.getKey(), entry.getValue()); + } + } + + ImmutableMap.Builder mergedOperatorStatsBuilder = ImmutableMap.builder(); + for (String key : groupedOperatorStats.keySet()) { + mergedOperatorStatsBuilder.put(key, BasicOperatorStats.merge(groupedOperatorStats.get(key))); + } + + return new PlanNodeStats( + planNodeId, + new Duration(planNodeScheduledTimeMillis, MILLISECONDS), + new Duration(planNodeCpuTimeMillis, MILLISECONDS), + new Duration(planNodeBlockedTimeMillis, MILLISECONDS), + planNodeInputPositions, + succinctBytes(planNodeInputDataSizeBytes), + succinctBytes(planNodePhysicalInputDataSizeBytes), + new Duration(planNodePhysicalInputReadNanos, NANOSECONDS), + planNodeOutputPositions, + succinctBytes(planNodeOutputDataSizeBytes), + succinctBytes(planNodeSpilledDataSizeBytes), + mergedOperatorStatsBuilder.buildOrThrow()); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStatsSummarizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStatsSummarizer.java index 9fdc5aea488c..aeb8376ffa7d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStatsSummarizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanNodeStatsSummarizer.java @@ -13,7 +13,9 @@ */ package io.trino.sql.planner.planprinter; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; import io.airlift.units.Duration; import io.trino.execution.StageInfo; import io.trino.execution.TaskInfo; @@ -51,15 +53,21 @@ public static Map aggregateStageStats(List public static Map aggregateTaskStats(List taskInfos) { - Map aggregatedStats = new HashMap<>(); + ListMultimap groupedStats = ArrayListMultimap.create(); List planNodeStats = taskInfos.stream() .map(TaskInfo::getStats) .flatMap(taskStats -> getPlanNodeStats(taskStats).stream()) .collect(toList()); for (PlanNodeStats stats : planNodeStats) { - aggregatedStats.merge(stats.getPlanNodeId(), stats, PlanNodeStats::mergeWith); + groupedStats.put(stats.getPlanNodeId(), stats); } - return aggregatedStats; + + ImmutableMap.Builder aggregatedStatsBuilder = ImmutableMap.builder(); + for (PlanNodeId planNodeId : groupedStats.keySet()) { + List groupedPlanNodeStats = groupedStats.get(planNodeId); + aggregatedStatsBuilder.put(planNodeId, groupedPlanNodeStats.get(0).mergeWith(groupedPlanNodeStats.subList(1, groupedPlanNodeStats.size()))); + } + return aggregatedStatsBuilder.buildOrThrow(); } private static List getPlanNodeStats(TaskStats taskStats)