diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index 41dcb537ff14..f637e20ad618 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -615,6 +615,11 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional get(StageId stageId, int partitionI return Optional.of(storage.get(stageId, partitionId)); } + /** + * Removes {@link TaskDescriptor} for a task identified by the stageId and partitionId. + * If the query has been terminated the call is ignored. + * + * @throws java.util.NoSuchElementException if {@link TaskDescriptor} for a given task does not exist + */ + public synchronized void remove(StageId stageId, int partitionId) + { + TaskDescriptors storage = storages.get(stageId.getQueryId()); + if (storage == null) { + // query has been terminated + return; + } + long previousReservedBytes = storage.getReservedBytes(); + storage.remove(stageId, partitionId); + long currentReservedBytes = storage.getReservedBytes(); + long delta = currentReservedBytes - previousReservedBytes; + updateMemoryReservation(delta); + } + /** * Notifies the storage that the query with a given queryId has been finished and the task descriptors can be safely discarded. *

@@ -178,6 +198,17 @@ public TaskDescriptor get(StageId stageId, int partitionId) return descriptor; } + public void remove(StageId stageId, int partitionId) + { + throwIfFailed(); + TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId); + TaskDescriptor descriptor = descriptors.remove(key); + if (descriptor == null) { + throw new NoSuchElementException(format("descriptor not found for key %s", key)); + } + reservedBytes -= descriptor.getRetainedSizeInBytes(); + } + public long getReservedBytes() { return reservedBytes; diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java index ef371a4c273c..937273844b27 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java @@ -80,8 +80,17 @@ public void testHappyPath() .flatMap(TestTaskDescriptorStorage::getCatalogName) .contains("catalog6"); - manager.destroy(QUERY_1); - manager.destroy(QUERY_2); + manager.remove(QUERY_1_STAGE_1, 0); + manager.remove(QUERY_2_STAGE_2, 1); + + assertThatThrownBy(() -> manager.get(QUERY_1_STAGE_1, 0)) + .hasMessageContaining("descriptor not found for key"); + assertThatThrownBy(() -> manager.get(QUERY_2_STAGE_2, 1)) + .hasMessageContaining("descriptor not found for key"); + + assertThat(manager.getReservedBytes()) + .isGreaterThanOrEqualTo(toBytes(5, KILOBYTE)) + .isLessThanOrEqualTo(toBytes(7, KILOBYTE)); } @Test @@ -140,6 +149,12 @@ public void testCapacityExceeded() .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); assertThatThrownBy(() -> manager.get(QUERY_1_STAGE_2, 0)) .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); + assertThatThrownBy(() -> manager.remove(QUERY_1_STAGE_1, 0)) + .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); + assertThatThrownBy(() -> manager.remove(QUERY_1_STAGE_1, 1)) + .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); + assertThatThrownBy(() -> manager.remove(QUERY_1_STAGE_2, 0)) + .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); // QUERY_2 is still active assertThat(manager.get(QUERY_2_STAGE_1, 0)) @@ -160,6 +175,8 @@ public void testCapacityExceeded() .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); assertThatThrownBy(() -> manager.get(QUERY_2_STAGE_1, 0)) .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); + assertThatThrownBy(() -> manager.remove(QUERY_2_STAGE_1, 0)) + .matches(TestTaskDescriptorStorage::isStorageCapacityExceededFailure); } private static TaskDescriptor createTaskDescriptor(int partitionId, DataSize retainedSize)