Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ public void testCreateWithNullableColumns()
}
}

@Test
// disabled due to https://github.com/prestodb/presto/issues/16081
@Test(enabled = false)
public void testAlterColumns()
{
String tableName = randomUUID().toString().toUpperCase(ENGLISH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,8 @@ public void testSetRole()
}
}

@Test(timeOut = 10000)
// Disabled due to https://github.com/prestodb/presto/issues/16080
@Test(enabled = false, timeOut = 10000)
public void testQueryCancelByInterrupt()
throws Exception
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.memory.LocalMemoryManager;
import com.facebook.presto.memory.MemoryPool;
import com.facebook.presto.memory.MemoryPoolListener;
import com.facebook.presto.memory.QueryContext;
import com.facebook.presto.memory.VoidTraversingQueryContextVisitor;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.PipelineContext;
Expand All @@ -29,85 +30,93 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;

import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.facebook.airlift.concurrent.Threads.threadsNamed;
import static com.facebook.presto.execution.MemoryRevokingUtils.getMemoryPools;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.PER_TASK_MEMORY_THRESHOLD;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.concurrent.Executors.newSingleThreadExecutor;

public class MemoryRevokingScheduler
{
private static final Logger log = Logger.get(MemoryRevokingScheduler.class);

private static final Ordering<SqlTask> ORDER_BY_CREATE_TIME = Ordering.natural().onResultOf(SqlTask::getTaskCreatedTime);

private final List<MemoryPool> memoryPools;
private final Function<QueryId, QueryContext> queryContextSupplier;
private final Supplier<List<SqlTask>> currentTasksSupplier;
private final ScheduledExecutorService taskManagementExecutor;
private final ExecutorService memoryRevocationExecutor;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that this instance is created in the constructor, it should be shutdown in the stop() method to avoid leaking it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't created in the constructor for the tests, but I'll actually just change the tests to get the executor from here, and that will remove any need to do any checking about the pool size.

private final double memoryRevokingThreshold;
private final double memoryRevokingTarget;
private final TaskSpillingStrategy spillingStrategy;

private final List<MemoryPool> memoryPools;
private final MemoryPoolListener memoryPoolListener = this::onMemoryReserved;

@Nullable
private ScheduledFuture<?> scheduledFuture;

private final AtomicBoolean checkPending = new AtomicBoolean();
private final boolean queryLimitSpillEnabled;

@Inject
public MemoryRevokingScheduler(
LocalMemoryManager localMemoryManager,
SqlTaskManager sqlTaskManager,
TaskManagementExecutor taskManagementExecutor,
FeaturesConfig config)
{
this(
ImmutableList.copyOf(getMemoryPools(localMemoryManager)),
requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor(),
requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getQueryContext,
config.getMemoryRevokingThreshold(),
config.getMemoryRevokingTarget(),
config.getTaskSpillingStrategy());
config.getTaskSpillingStrategy(),
config.isQueryLimitSpillEnabled());
}

@VisibleForTesting
MemoryRevokingScheduler(
List<MemoryPool> memoryPools,
Supplier<List<SqlTask>> currentTasksSupplier,
ScheduledExecutorService taskManagementExecutor,
Function<QueryId, QueryContext> queryContextSupplier,
double memoryRevokingThreshold,
double memoryRevokingTarget,
TaskSpillingStrategy taskSpillingStrategy)
TaskSpillingStrategy taskSpillingStrategy,
boolean queryLimitSpillEnabled)
{
this.memoryPools = ImmutableList.copyOf(requireNonNull(memoryPools, "memoryPools is null"));
this.currentTasksSupplier = requireNonNull(currentTasksSupplier, "currentTasksSupplier is null");
this.taskManagementExecutor = requireNonNull(taskManagementExecutor, "taskManagementExecutor is null");
this.currentTasksSupplier = requireNonNull(currentTasksSupplier, "allTasksSupplier is null");
this.queryContextSupplier = requireNonNull(queryContextSupplier, "queryContextSupplier is null");
this.memoryRevokingThreshold = checkFraction(memoryRevokingThreshold, "memoryRevokingThreshold");
this.memoryRevokingTarget = checkFraction(memoryRevokingTarget, "memoryRevokingTarget");
// by using a single thread executor, we don't need to worry about locking to ensure only
// one revocation request per-query/memory pool is processed at a time.
this.memoryRevocationExecutor = newSingleThreadExecutor(threadsNamed("memory-revocation"));
this.spillingStrategy = requireNonNull(taskSpillingStrategy, "taskSpillingStrategy is null");
checkArgument(spillingStrategy != PER_TASK_MEMORY_THRESHOLD, "spilling strategy cannot be PER_TASK_MEMORY_THRESHOLD in MemoryRevokingScheduler");
checkArgument(
memoryRevokingTarget <= memoryRevokingThreshold,
"memoryRevokingTarget should be less than or equal memoryRevokingThreshold, but got %s and %s respectively",
memoryRevokingTarget, memoryRevokingThreshold);
this.queryLimitSpillEnabled = queryLimitSpillEnabled;
}

private static double checkFraction(double value, String valueName)
Expand All @@ -120,95 +129,131 @@ private static double checkFraction(double value, String valueName)
@PostConstruct
public void start()
{
registerPeriodicCheck();
registerPoolListeners();
}

private void registerPeriodicCheck()
{
this.scheduledFuture = taskManagementExecutor.scheduleWithFixedDelay(() -> {
try {
requestMemoryRevokingIfNeeded();
}
catch (Exception e) {
log.error(e, "Error requesting system memory revoking");
}
}, 1, 1, SECONDS);
}

@PreDestroy
public void stop()
{
if (scheduledFuture != null) {
scheduledFuture.cancel(true);
scheduledFuture = null;
}

memoryPools.forEach(memoryPool -> memoryPool.removeListener(memoryPoolListener));
memoryRevocationExecutor.shutdown();
}

@VisibleForTesting
void registerPoolListeners()
private void registerPoolListeners()
{
memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener));
}

@VisibleForTesting
void awaitAsynchronousCallbacksRun()
throws InterruptedException
{
memoryRevocationExecutor.invokeAll(singletonList((Callable<?>) () -> null));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clever, I didn't realize that invokeAll() blocked until completion like this.

}

private void onMemoryReserved(MemoryPool memoryPool, QueryId queryId, long queryMemoryReservation)
{
try {
if (!memoryRevokingNeeded(memoryPool)) {
return;
if (queryLimitSpillEnabled) {
QueryContext queryContext = queryContextSupplier.apply(queryId);
verify(queryContext != null, "QueryContext not found for queryId %s", queryId);
long maxTotalMemory = queryContext.getMaxTotalMemory();
if (memoryRevokingNeededForQuery(queryMemoryReservation, maxTotalMemory)) {
log.debug("Scheduling check for %s", queryId);
scheduleQueryRevoking(queryContext, maxTotalMemory);
}
}

if (checkPending.compareAndSet(false, true)) {
if (memoryRevokingNeededForPool(memoryPool)) {
log.debug("Scheduling check for %s", memoryPool);
scheduleRevoking();
scheduleMemoryPoolRevoking(memoryPool);
}
}
catch (Exception e) {
log.error(e, "Error when acting on memory pool reservation");
}
}

@VisibleForTesting
void requestMemoryRevokingIfNeeded()
private boolean memoryRevokingNeededForQuery(long queryMemoryReservation, long maxTotalMemory)
{
if (checkPending.compareAndSet(false, true)) {
runMemoryRevoking();
}
return queryMemoryReservation >= maxTotalMemory;
}

private void scheduleRevoking()
private void scheduleQueryRevoking(QueryContext queryContext, long maxTotalMemory)
{
taskManagementExecutor.execute(() -> {
memoryRevocationExecutor.execute(() -> {
try {
runMemoryRevoking();
revokeQueryMemory(queryContext, maxTotalMemory);
}
catch (Exception e) {
log.error(e, "Error requesting memory revoking");
}
});
}

private synchronized void runMemoryRevoking()
private void revokeQueryMemory(QueryContext queryContext, long maxTotalMemory)
{
if (checkPending.getAndSet(false)) {
Collection<SqlTask> allTasks = null;
for (MemoryPool memoryPool : memoryPools) {
if (!memoryRevokingNeeded(memoryPool)) {
continue;
QueryId queryId = queryContext.getQueryId();
MemoryPool memoryPool = queryContext.getMemoryPool();
// get a fresh value for queryTotalMemory in case it's changed (e.g. by a previous revocation request)
long queryTotalMemory = getTotalQueryMemoryReservation(queryId, memoryPool);
// order tasks by decreasing revocableMemory so that we don't spill more tasks than needed
SortedMap<Long, TaskContext> queryTaskContextsMap = new TreeMap<>(Comparator.reverseOrder());
queryContext.getAllTaskContexts()
.forEach(taskContext -> queryTaskContextsMap.put(taskContext.getTaskMemoryContext().getRevocableMemory(), taskContext));

AtomicLong remainingBytesToRevoke = new AtomicLong(queryTotalMemory - maxTotalMemory);
Collection<TaskContext> queryTaskContexts = queryTaskContextsMap.values();
remainingBytesToRevoke.addAndGet(-MemoryRevokingSchedulerUtils.getMemoryAlreadyBeingRevoked(queryTaskContexts, remainingBytesToRevoke.get()));
for (TaskContext taskContext : queryTaskContexts) {
if (remainingBytesToRevoke.get() <= 0) {
break;
}
taskContext.accept(new VoidTraversingQueryContextVisitor<AtomicLong>()
{
@Override
public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong remainingBytesToRevoke)
{
if (remainingBytesToRevoke.get() > 0) {
long revokedBytes = operatorContext.requestMemoryRevoking();
if (revokedBytes > 0) {
remainingBytesToRevoke.addAndGet(-revokedBytes);
log.debug("taskId=%s: requested revoking %s; remaining %s", taskContext.getTaskId(), revokedBytes, remainingBytesToRevoke);
}
}
return null;
}
}, remainingBytesToRevoke);
}
}

if (allTasks == null) {
allTasks = requireNonNull(currentTasksSupplier.get());
}
private static long getTotalQueryMemoryReservation(QueryId queryId, MemoryPool memoryPool)
{
return memoryPool.getQueryMemoryReservation(queryId) + memoryPool.getQueryRevocableMemoryReservation(queryId);
}

requestMemoryRevoking(memoryPool, allTasks);
private void scheduleMemoryPoolRevoking(MemoryPool memoryPool)
{
memoryRevocationExecutor.execute(() -> {
try {
runMemoryPoolRevoking(memoryPool);
}
catch (Exception e) {
log.error(e, "Error requesting memory revoking");
}
});
}

@VisibleForTesting
void runMemoryPoolRevoking(MemoryPool memoryPool)
{
if (!memoryRevokingNeededForPool(memoryPool)) {
return;
}
Collection<SqlTask> allTasks = requireNonNull(currentTasksSupplier.get());
requestMemoryPoolRevoking(memoryPool, allTasks);
}

private void requestMemoryRevoking(MemoryPool memoryPool, Collection<SqlTask> allTasks)
private void requestMemoryPoolRevoking(MemoryPool memoryPool, Collection<SqlTask> allTasks)
{
long remainingBytesToRevoke = (long) (-memoryPool.getFreeBytes() + (memoryPool.getMaxBytes() * (1.0 - memoryRevokingTarget)));
ArrayList<SqlTask> runningTasksInPool = findRunningTasksInMemoryPool(allTasks, memoryPool);
Expand All @@ -218,7 +263,7 @@ private void requestMemoryRevoking(MemoryPool memoryPool, Collection<SqlTask> al
}
}

private boolean memoryRevokingNeeded(MemoryPool memoryPool)
private boolean memoryRevokingNeededForPool(MemoryPool memoryPool)
{
return memoryPool.getReservedRevocableBytes() > 0
&& memoryPool.getFreeBytes() <= memoryPool.getMaxBytes() * (1.0 - memoryRevokingThreshold);
Expand Down
Loading