diff --git a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java index c521a6e184fb6..63f74f2dd1ee2 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java @@ -40,6 +40,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; +import static com.facebook.presto.execution.MemoryRevokingUtils.getMemoryPools; import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.PER_TASK_MEMORY_THRESHOLD; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -111,15 +112,6 @@ private static double checkFraction(double value, String valueName) return value; } - private static List getMemoryPools(LocalMemoryManager localMemoryManager) - { - requireNonNull(localMemoryManager, "localMemoryManager can not be null"); - ImmutableList.Builder builder = new ImmutableList.Builder<>(); - builder.add(localMemoryManager.getGeneralPool()); - localMemoryManager.getReservedPool().ifPresent(builder::add); - return builder.build(); - } - @PostConstruct public void start() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingUtils.java b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingUtils.java new file mode 100644 index 0000000000000..07f0a061972dc --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingUtils.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.presto.memory.LocalMemoryManager; +import com.facebook.presto.memory.MemoryPool; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class MemoryRevokingUtils +{ + private MemoryRevokingUtils() {} + + public static List getMemoryPools(LocalMemoryManager localMemoryManager) + { + requireNonNull(localMemoryManager, "localMemoryManager can not be null"); + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + builder.add(localMemoryManager.getGeneralPool()); + localMemoryManager.getReservedPool().ifPresent(builder::add); + return builder.build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskManager.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskManager.java index 4b7e84cd2be98..821dae6322725 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlTaskManager.java @@ -316,6 +316,11 @@ public List getAllTasks() return ImmutableList.copyOf(tasks.asMap().values()); } + public SqlTask getTask(TaskId taskId) + { + return tasks.getUnchecked(taskId); + } + @Override public List getAllTaskInfo() { diff --git a/presto-main/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java index e622c8a3ebf97..fbe7737b47c15 
100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java @@ -14,11 +14,15 @@ package com.facebook.presto.execution; import com.facebook.airlift.log.Logger; +import com.facebook.presto.memory.LocalMemoryManager; +import com.facebook.presto.memory.MemoryPool; import com.facebook.presto.memory.QueryContext; +import com.facebook.presto.memory.TaskRevocableMemoryListener; import com.facebook.presto.memory.VoidTraversingQueryContextVisitor; import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import javax.annotation.Nullable; import javax.annotation.PostConstruct; @@ -31,8 +35,10 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; import java.util.function.Supplier; +import static com.facebook.presto.execution.MemoryRevokingUtils.getMemoryPools; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; @@ -40,7 +46,8 @@ public class TaskThresholdMemoryRevokingScheduler { private static final Logger log = Logger.get(TaskThresholdMemoryRevokingScheduler.class); - private final Supplier> currentTasksSupplier; + private final Supplier> allTasksSupplier; + private final Function taskSupplier; private final ScheduledExecutorService taskManagementExecutor; private final long maxRevocableMemoryPerTask; @@ -50,15 +57,20 @@ public class TaskThresholdMemoryRevokingScheduler private ScheduledFuture scheduledFuture; private final AtomicBoolean checkPending = new AtomicBoolean(); + private final List memoryPools; + private final TaskRevocableMemoryListener taskRevocableMemoryListener = 
TaskRevocableMemoryListener.onMemoryReserved(this::onMemoryReserved); @Inject public TaskThresholdMemoryRevokingScheduler( + LocalMemoryManager localMemoryManager, SqlTaskManager sqlTaskManager, TaskManagementExecutor taskManagementExecutor, FeaturesConfig config) { this( + ImmutableList.copyOf(getMemoryPools(localMemoryManager)), requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks, + requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getTask, requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor(), requireNonNull(config.getMaxRevocableMemoryPerTask(), "maxRevocableMemoryPerTask cannot be null").toBytes()); log.debug("Using TaskThresholdMemoryRevokingScheduler spilling strategy"); @@ -66,11 +78,15 @@ public TaskThresholdMemoryRevokingScheduler( @VisibleForTesting TaskThresholdMemoryRevokingScheduler( - Supplier> currentTasksSupplier, + List memoryPools, + Supplier> allTasksSupplier, + Function taskSupplier, ScheduledExecutorService taskManagementExecutor, long maxRevocableMemoryPerTask) { - this.currentTasksSupplier = requireNonNull(currentTasksSupplier, "currentTasksSupplier is null"); + this.memoryPools = ImmutableList.copyOf(requireNonNull(memoryPools, "memoryPools is null")); + this.allTasksSupplier = requireNonNull(allTasksSupplier, "allTasksSupplier is null"); + this.taskSupplier = requireNonNull(taskSupplier, "taskSupplier is null"); this.taskManagementExecutor = requireNonNull(taskManagementExecutor, "taskManagementExecutor is null"); this.maxRevocableMemoryPerTask = maxRevocableMemoryPerTask; } @@ -79,6 +95,7 @@ public TaskThresholdMemoryRevokingScheduler( public void start() { registerTaskMemoryPeriodicCheck(); + registerPoolListeners(); } private void registerTaskMemoryPeriodicCheck() @@ -100,6 +117,14 @@ public void stop() scheduledFuture.cancel(true); scheduledFuture = null; } + + memoryPools.forEach(memoryPool -> 
memoryPool.removeTaskRevocableMemoryListener(taskRevocableMemoryListener)); + } + + @VisibleForTesting + void registerPoolListeners() + { + memoryPools.forEach(memoryPool -> memoryPool.addTaskRevocableMemoryListener(taskRevocableMemoryListener)); } @VisibleForTesting @@ -110,16 +135,52 @@ void revokeHighMemoryTasksIfNeeded() } } + private void onMemoryReserved(TaskId taskId, MemoryPool memoryPool) + { + try { + SqlTask task = taskSupplier.apply(taskId); + if (!memoryRevokingNeeded(task)) { + return; + } + + if (checkPending.compareAndSet(false, true)) { + log.debug("Scheduling check for %s", memoryPool); + scheduleRevoking(); + } + } + catch (Throwable e) { + log.error(e, "Error when acting on memory pool reservation"); + } + } + + private void scheduleRevoking() + { + taskManagementExecutor.execute(() -> { + try { + revokeHighMemoryTasks(); + } + catch (Throwable e) { + log.error(e, "Error requesting memory revoking"); + } + }); + } + + private boolean memoryRevokingNeeded(SqlTask task) + { + return task.getTaskInfo().getStats().getRevocableMemoryReservationInBytes() >= maxRevocableMemoryPerTask; + } + private synchronized void revokeHighMemoryTasks() { if (checkPending.getAndSet(false)) { - Collection sqlTasks = requireNonNull(currentTasksSupplier.get()); + Collection sqlTasks = requireNonNull(allTasksSupplier.get()); for (SqlTask task : sqlTasks) { - long currentTaskRevocableMemory = task.getTaskInfo().getStats().getRevocableMemoryReservationInBytes(); - if (currentTaskRevocableMemory < maxRevocableMemoryPerTask) { + if (!memoryRevokingNeeded(task)) { continue; } + long currentTaskRevocableMemory = task.getTaskInfo().getStats().getRevocableMemoryReservationInBytes(); + AtomicLong remainingBytesToRevokeAtomic = new AtomicLong(currentTaskRevocableMemory - maxRevocableMemoryPerTask); task.getQueryContext().accept(new VoidTraversingQueryContextVisitor() { diff --git a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java 
b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java index 73f53f9b3bc11..1a49b8b62c951 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/MemoryPool.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.memory; +import com.facebook.presto.execution.TaskId; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryAllocation; import com.facebook.presto.spi.memory.MemoryPoolId; @@ -71,6 +72,8 @@ public class MemoryPool private final List listeners = new CopyOnWriteArrayList<>(); + private final List taskRevocableMemoryListeners = new CopyOnWriteArrayList<>(); + public MemoryPool(MemoryPoolId id, DataSize size) { this.id = requireNonNull(id, "name is null"); @@ -106,6 +109,16 @@ public void removeListener(MemoryPoolListener listener) listeners.remove(requireNonNull(listener, "listener cannot be null")); } + public void addTaskRevocableMemoryListener(TaskRevocableMemoryListener listener) + { + taskRevocableMemoryListeners.add(requireNonNull(listener, "listener cannot be null")); + } + + public void removeTaskRevocableMemoryListener(TaskRevocableMemoryListener listener) + { + taskRevocableMemoryListeners.remove(requireNonNull(listener, "listener cannot be null")); + } + /** * Reserves the given number of bytes. Caller should wait on the returned future, before allocating more memory. 
*/ @@ -141,6 +154,11 @@ private void onMemoryReserved() listeners.forEach(listener -> listener.onMemoryReserved(this)); } + public void onTaskMemoryReserved(TaskId taskId) + { + taskRevocableMemoryListeners.forEach(listener -> listener.onMemoryReserved(taskId, this)); + } + public ListenableFuture reserveRevocable(QueryId queryId, long bytes) { checkArgument(bytes >= 0, "bytes is negative"); diff --git a/presto-main/src/main/java/com/facebook/presto/memory/TaskRevocableMemoryListener.java b/presto-main/src/main/java/com/facebook/presto/memory/TaskRevocableMemoryListener.java new file mode 100644 index 0000000000000..a2e7eae44c19f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/memory/TaskRevocableMemoryListener.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.memory; + +import com.facebook.presto.execution.TaskId; + +import java.util.function.BiConsumer; + +public interface TaskRevocableMemoryListener +{ + /** + * Listener function that is called when a Task reserves + * memory in a given MemoryPool successfully + * + * @param taskId the {@link TaskId} of the task that reserved the memory + * @param memoryPool the {@link MemoryPool} where the reservation took place + */ + void onMemoryReserved(TaskId taskId, MemoryPool memoryPool); + + static TaskRevocableMemoryListener onMemoryReserved(BiConsumer action) + { + return action::accept; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java index fa2100e455c1d..3fd4b744881e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/OperatorContext.java @@ -255,7 +255,7 @@ public LocalMemoryContext localSystemMemoryContext() // caller shouldn't close this context as it's managed by the OperatorContext public LocalMemoryContext localRevocableMemoryContext() { - return new InternalLocalMemoryContext(operatorMemoryContext.localRevocableMemoryContext(), revocableMemoryFuture, () -> {}, false); + return new InternalLocalMemoryContext(operatorMemoryContext.localRevocableMemoryContext(), revocableMemoryFuture, this::updateTaskRevocableMemoryReservation, false); } // caller shouldn't close this context as it's managed by the OperatorContext @@ -267,7 +267,7 @@ public AggregatedMemoryContext aggregateUserMemoryContext() // caller shouldn't close this context as it's managed by the OperatorContext public AggregatedMemoryContext aggregateRevocableMemoryContext() { - return new InternalAggregatedMemoryContext(operatorMemoryContext.aggregateRevocableMemoryContext(), memoryFuture, () -> {}, false); + return new 
InternalAggregatedMemoryContext(operatorMemoryContext.aggregateRevocableMemoryContext(), memoryFuture, this::updateTaskRevocableMemoryReservation, false); } // caller should close this context as it's a new context @@ -287,6 +287,12 @@ private void updatePeakMemoryReservations() peakTotalMemoryReservation.accumulateAndGet(totalMemory, Math::max); } + // listen to revocable memory allocations and call any listeners waiting on task memory allocation + private void updateTaskRevocableMemoryReservation() + { + driverContext.getPipelineContext().getTaskContext().getQueryContext().getMemoryPool().onTaskMemoryReserved(driverContext.getTaskId()); + } + public long getReservedRevocableBytes() { return operatorMemoryContext.getRevocableMemory(); diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestMemoryRevokingScheduler.java b/presto-main/src/test/java/com/facebook/presto/execution/TestMemoryRevokingScheduler.java index 01c761dbac0fa..b527074f4823b 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestMemoryRevokingScheduler.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestMemoryRevokingScheduler.java @@ -41,6 +41,7 @@ import com.google.common.base.Functions; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import io.airlift.units.DataSize; @@ -338,8 +339,9 @@ public void testTaskThresholdRevokingScheduler() allOperatorContexts = ImmutableSet.of(operatorContext11, operatorContext12, operatorContext2); List tasks = ImmutableList.of(sqlTask1, sqlTask2); + ImmutableMap taskMap = ImmutableMap.of(sqlTask1.getTaskId(), sqlTask1, sqlTask2.getTaskId(), sqlTask2); TaskThresholdMemoryRevokingScheduler scheduler = new TaskThresholdMemoryRevokingScheduler( - () -> tasks, executor, 5L); + singletonList(memoryPool), () -> tasks, taskMap::get, executor, 5L); 
assertMemoryRevokingNotRequested(); @@ -386,6 +388,70 @@ public void testTaskThresholdRevokingScheduler() assertMemoryRevokingRequestedFor(operatorContext2); } + @Test + public void testTaskThresholdRevokingSchedulerImmediate() + throws Exception + { + SqlTask sqlTask1 = newSqlTask(); + TestOperatorContext operatorContext11 = createTestingOperatorContexts(sqlTask1, "operator11"); + TestOperatorContext operatorContext12 = createTestingOperatorContexts(sqlTask1, "operator12"); + + SqlTask sqlTask2 = newSqlTask(); + TestOperatorContext operatorContext2 = createTestingOperatorContexts(sqlTask2, "operator2"); + + allOperatorContexts = ImmutableSet.of(operatorContext11, operatorContext12, operatorContext2); + List tasks = ImmutableList.of(sqlTask1, sqlTask2); + ImmutableMap taskMap = ImmutableMap.of(sqlTask1.getTaskId(), sqlTask1, sqlTask2.getTaskId(), sqlTask2); + TaskThresholdMemoryRevokingScheduler scheduler = new TaskThresholdMemoryRevokingScheduler( + singletonList(memoryPool), () -> tasks, taskMap::get, executor, 5L); + scheduler.registerPoolListeners(); // no periodic check initiated + + assertMemoryRevokingNotRequested(); + + operatorContext11.localRevocableMemoryContext().setBytes(3); + operatorContext2.localRevocableMemoryContext().setBytes(2); + // at this point, Task1 = 3 total bytes, Task2 = 2 total bytes + + // this ensures that we are waiting for the memory revocation listener and not using polling-based revoking + awaitAsynchronousCallbacksRun(); + assertMemoryRevokingNotRequested(); + + operatorContext12.localRevocableMemoryContext().setBytes(3); + // at this point, Task1 = 6 total bytes, Task2 = 2 total bytes + + awaitAsynchronousCallbacksRun(); + // only operator11 should revoke since we need to revoke only 1 byte + // (operator11 + operator12) - threshold => (3 + 3) - 5 = 1 byte to revoke + assertMemoryRevokingRequestedFor(operatorContext11); + + // revoke 2 bytes in operator11 + operatorContext11.localRevocableMemoryContext().setBytes(1); + // at 
this point, Task1 = 4 total bytes, Task2 = 2 total bytes + operatorContext11.resetMemoryRevokingRequested(); + awaitAsynchronousCallbacksRun(); + assertMemoryRevokingNotRequested(); + + operatorContext12.localRevocableMemoryContext().setBytes(6); // operator12 fills up + // at this point, Task1 = 7 total bytes, Task2 = 2 total bytes + awaitAsynchronousCallbacksRun(); + // both operator11 and operator12 are revoking since we revoke in order of operator creation within the task until we are below the memory revoking threshold + assertMemoryRevokingRequestedFor(operatorContext11, operatorContext12); + + operatorContext11.localRevocableMemoryContext().setBytes(2); + operatorContext11.resetMemoryRevokingRequested(); + operatorContext12.localRevocableMemoryContext().setBytes(2); + operatorContext12.resetMemoryRevokingRequested(); + // at this point, Task1 = 4 total bytes, Task2 = 2 total bytes + + awaitAsynchronousCallbacksRun(); + assertMemoryRevokingNotRequested(); // no need to revoke + + operatorContext2.localRevocableMemoryContext().setBytes(6); + // at this point, Task1 = 4 total bytes, Task2 = 6 total bytes, operators in Task2 must be revoked + awaitAsynchronousCallbacksRun(); + assertMemoryRevokingRequestedFor(operatorContext2); + } + private OperatorContext createContexts(SqlTask sqlTask) { TaskContext taskContext = sqlTask.getQueryContext().addTaskContext(new TaskStateMachine(new TaskId("q", 1, 0, 1), executor), session, false, false, false, false, false);