diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java index 82ab33861f81..6524ffe5e391 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java @@ -293,10 +293,14 @@ private class MemoryUsageListener { private long previousUserMemory; private long previousRevocableMemory; + private boolean finalUsageReported; @Override public synchronized void stateChanged(TaskStatus taskStatus) { + if (finalUsageReported) { + return; + } long currentUserMemory = taskStatus.getMemoryReservation().toBytes(); long currentRevocableMemory = taskStatus.getRevocableMemoryReservation().toBytes(); long deltaUserMemoryInBytes = currentUserMemory - previousUserMemory; @@ -305,6 +309,14 @@ public synchronized void stateChanged(TaskStatus taskStatus) previousUserMemory = currentUserMemory; previousRevocableMemory = currentRevocableMemory; stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes); + + if (taskStatus.getState().isDone()) { + // if task is finished perform final memory update to 0 + stateMachine.updateMemoryUsage(-currentUserMemory, -currentRevocableMemory, -(currentUserMemory + currentRevocableMemory)); + previousUserMemory = 0; + previousRevocableMemory = 0; + finalUsageReported = true; + } } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index c9bdeabaffa2..67c26a31c81f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -374,9 +374,9 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) long cumulativeUserMemory = 0; long failedCumulativeUserMemory = 0; - long userMemoryReservation = 0; - long revocableMemoryReservation = 0; - long totalMemoryReservation = 0; + long userMemoryReservation = currentUserMemory.get(); + long revocableMemoryReservation = currentRevocableMemory.get(); + long totalMemoryReservation = currentTotalMemory.get(); long peakUserMemoryReservation = peakUserMemory.get(); long peakRevocableMemoryReservation = peakRevocableMemory.get(); @@ -459,12 +459,6 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) failedCumulativeUserMemory += taskStats.getCumulativeUserMemory(); } - long taskUserMemory = taskStats.getUserMemoryReservation().toBytes(); - long taskRevocableMemory = taskStats.getRevocableMemoryReservation().toBytes(); - userMemoryReservation += taskUserMemory; - revocableMemoryReservation += taskRevocableMemory; - totalMemoryReservation += taskUserMemory + taskRevocableMemory; - totalScheduledTime += taskStats.getTotalScheduledTime().roundTo(NANOSECONDS); totalCpuTime += taskStats.getTotalCpuTime().roundTo(NANOSECONDS); totalBlockedTime += taskStats.getTotalBlockedTime().roundTo(NANOSECONDS);