Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
Expand All @@ -50,8 +50,8 @@
import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ClassLoaderUtils;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ToStringBuilder;
import org.junit.platform.engine.ConfigurationParameters;
import org.junit.platform.engine.UniqueId;

/**
* @since 6.1
Expand All @@ -76,10 +76,12 @@ public ConcurrentHierarchicalTestExecutorService(ParallelExecutionConfiguration

ConcurrentHierarchicalTestExecutorService(ParallelExecutionConfiguration configuration, ClassLoader classLoader) {
ThreadFactory threadFactory = new WorkerThreadFactory(classLoader);
threadPool = new ThreadPoolExecutor(configuration.getCorePoolSize(), configuration.getMaxPoolSize(),
configuration.getKeepAliveSeconds(), SECONDS, new SynchronousQueue<>(), threadFactory);
parallelism = configuration.getParallelism();
workerLeaseManager = new WorkerLeaseManager(parallelism, this::maybeStartWorker);
var rejectedExecutionHandler = new LeaseAwareRejectedExecutionHandler(workerLeaseManager);
threadPool = new ThreadPoolExecutor(configuration.getCorePoolSize(), configuration.getMaxPoolSize(),
configuration.getKeepAliveSeconds(), SECONDS, new SynchronousQueue<>(), threadFactory,
rejectedExecutionHandler);
LOGGER.trace(() -> "initialized thread pool for parallelism of " + configuration.getParallelism());
}

Expand Down Expand Up @@ -146,26 +148,25 @@ private void maybeStartWorker(BooleanSupplier doneCondition) {
if (workerLease == null) {
return;
}
try {
threadPool.execute(() -> {
LOGGER.trace(() -> "starting worker");
try {
WorkerThread.getOrThrow().processQueueEntries(workerLease, doneCondition);
}
finally {
workerLease.release(false);
LOGGER.trace(() -> "stopping worker");
}
maybeStartWorker(doneCondition);
});
}
catch (RejectedExecutionException e) {
workerLease.release(false);
if (threadPool.isShutdown() || workerLeaseManager.isAtLeastOneLeaseTaken()) {
return;
threadPool.execute(new RunLeaseAwareWorker(workerLease,
() -> WorkerThread.getOrThrow().processQueueEntries(workerLease, doneCondition),
() -> this.maybeStartWorker(doneCondition)));
}

private record RunLeaseAwareWorker(WorkerLease workerLease, Runnable work, Runnable onWorkerFinished)
implements Runnable {

@Override
public void run() {
LOGGER.trace(() -> "starting worker");
try {
work.run();
}
finally {
workerLease.release(false);
LOGGER.trace(() -> "stopping worker");
}
LOGGER.error(e, () -> "failed to submit worker to thread pool");
throw e;
onWorkerFinished.run();
}
}

Expand Down Expand Up @@ -226,13 +227,25 @@ void processQueueEntries(WorkerLease workerLease, BooleanSupplier doneCondition)
LOGGER.trace(() -> "yielding resource lock");
break;
}
var entry = workQueue.poll();
if (entry == null) {
LOGGER.trace(() -> "no queue entry available");
if (workQueue.isEmpty()) {
LOGGER.trace(() -> "no queue entries available");
break;
}
LOGGER.trace(() -> "processing: " + entry.task);
execute(entry);
processQueueEntries();
}
}

private void processQueueEntries() {
var queueEntriesByResult = tryToStealWorkWithoutBlocking(workQueue);
var queueModified = queueEntriesByResult.containsKey(WorkStealResult.EXECUTED_BY_THIS_WORKER) //
|| queueEntriesByResult.containsKey(WorkStealResult.EXECUTED_BY_DIFFERENT_WORKER);
if (queueModified) {
return;
}
var entriesRequiringResourceLocks = queueEntriesByResult.get(WorkStealResult.RESOURCE_LOCK_UNAVAILABLE);
if (entriesRequiringResourceLocks != null) {
// One entry at a time to avoid blocking too much
tryToStealWork(entriesRequiringResourceLocks.get(0), BlockingMode.BLOCKING);
Comment on lines +247 to +248
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate on this? Is there evidence this is better than doing all of them?

}
}

Expand Down Expand Up @@ -265,18 +278,17 @@ void invokeAll(List<? extends TestTask> testTasks) {

List<TestTask> isolatedTasks = new ArrayList<>(testTasks.size());
List<TestTask> sameThreadTasks = new ArrayList<>(testTasks.size());
var forkedChildren = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
var reverseQueueEntries = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
executeAll(sameThreadTasks);
var queueEntriesByResult = tryToStealWorkWithoutBlocking(forkedChildren);
tryToStealWorkWithBlocking(queueEntriesByResult);
waitFor(queueEntriesByResult);
var reverseQueueEntriesByResult = tryToStealWorkWithoutBlocking(reverseQueueEntries);
tryToStealWorkWithBlocking(reverseQueueEntriesByResult);
waitFor(reverseQueueEntriesByResult);
executeAll(isolatedTasks);
}

private List<WorkQueue.Entry> forkConcurrentChildren(List<? extends TestTask> children,
Consumer<TestTask> isolatedTaskCollector, List<TestTask> sameThreadTasks) {

int index = 0;
List<WorkQueue.Entry> queueEntries = new ArrayList<>(children.size());
for (TestTask child : children) {
if (requiresGlobalReadWriteLock(child)) {
Expand All @@ -286,7 +298,7 @@ else if (child.getExecutionMode() == SAME_THREAD) {
sameThreadTasks.add(child);
}
else {
queueEntries.add(WorkQueue.Entry.createWithIndex(child, index++));
queueEntries.add(workQueue.createEntry(child));
}
}

Expand All @@ -299,32 +311,29 @@ else if (child.getExecutionMode() == SAME_THREAD) {
}
forkAll(queueEntries);
}

queueEntries.sort(reverseOrder());
return queueEntries;
}

private Map<WorkStealResult, List<WorkQueue.Entry>> tryToStealWorkWithoutBlocking(
List<WorkQueue.Entry> forkedChildren) {
Iterable<WorkQueue.Entry> queueEntries) {

Map<WorkStealResult, List<WorkQueue.Entry>> queueEntriesByResult = new EnumMap<>(WorkStealResult.class);
if (!forkedChildren.isEmpty()) {
forkedChildren.sort(reverseOrder());
tryToStealWork(forkedChildren, BlockingMode.NON_BLOCKING, queueEntriesByResult);
}
tryToStealWork(queueEntries, BlockingMode.NON_BLOCKING, queueEntriesByResult);
return queueEntriesByResult;
}

private void tryToStealWorkWithBlocking(Map<WorkStealResult, List<WorkQueue.Entry>> queueEntriesByResult) {
var childrenRequiringResourceLocks = queueEntriesByResult.remove(WorkStealResult.RESOURCE_LOCK_UNAVAILABLE);
if (childrenRequiringResourceLocks == null) {
var entriesRequiringResourceLocks = queueEntriesByResult.remove(WorkStealResult.RESOURCE_LOCK_UNAVAILABLE);
if (entriesRequiringResourceLocks == null) {
return;
}
tryToStealWork(childrenRequiringResourceLocks, BlockingMode.BLOCKING, queueEntriesByResult);
tryToStealWork(entriesRequiringResourceLocks, BlockingMode.BLOCKING, queueEntriesByResult);
}

private void tryToStealWork(List<WorkQueue.Entry> children, BlockingMode blocking,
private void tryToStealWork(Iterable<WorkQueue.Entry> entries, BlockingMode blocking,
Map<WorkStealResult, List<WorkQueue.Entry>> queueEntriesByResult) {
for (var entry : children) {
for (var entry : entries) {
var state = tryToStealWork(entry, blocking);
queueEntriesByResult.computeIfAbsent(state, __ -> new ArrayList<>()).add(entry);
}
Expand Down Expand Up @@ -528,24 +537,27 @@ private enum BlockingMode {
NON_BLOCKING, BLOCKING
}

private static class WorkQueue {

private final EntryOrdering ordering = new EntryOrdering();
private final Queue<Entry> queue = new PriorityBlockingQueue<>();
Comment on lines -533 to -534
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like I forgot to use the custom comparator... 😳

private static class WorkQueue implements Iterable<WorkQueue.Entry> {
private final AtomicLong index = new AtomicLong();
private final Set<Entry> queue = new ConcurrentSkipListSet<>();

Entry add(TestTask task) {
Entry entry = Entry.create(task);
Entry entry = createEntry(task);
LOGGER.trace(() -> "forking: " + entry.task);
return doAdd(entry);
}

Entry createEntry(TestTask task) {
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
return new Entry(task, new CompletableFuture<>(), level, index.getAndIncrement());
}

void addAll(Collection<Entry> entries) {
entries.forEach(this::doAdd);
}

void reAdd(Entry entry) {
LOGGER.trace(() -> "re-enqueuing: " + entry.task);
ordering.incrementAttempts(entry);
doAdd(entry);
}

Expand All @@ -557,11 +569,6 @@ private Entry doAdd(Entry entry) {
return entry;
}

@Nullable
Entry poll() {
return queue.poll();
}

boolean remove(Entry entry) {
return queue.remove(entry);
}
Expand All @@ -570,17 +577,13 @@ boolean isEmpty() {
return queue.isEmpty();
}

private record Entry(TestTask task, CompletableFuture<@Nullable Void> future, int level, int index)
implements Comparable<Entry> {

static Entry create(TestTask task) {
return createWithIndex(task, 0);
}
@Override
public Iterator<Entry> iterator() {
return queue.iterator();
}

static Entry createWithIndex(TestTask task, int index) {
int level = task.getTestDescriptor().getUniqueId().getSegments().size();
return new Entry(task, new CompletableFuture<>(), level, index);
}
private record Entry(TestTask task, CompletableFuture<@Nullable Void> future, int level, long index)
implements Comparable<Entry> {

@SuppressWarnings("FutureReturnValueIgnored")
Entry {
Expand All @@ -604,52 +607,14 @@ public int compareTo(Entry that) {
if (result != 0) {
return result;
}
return Integer.compare(that.index, this.index);
return Long.compare(that.index, this.index);
}

private boolean isContainer() {
return task.getTestDescriptor().isContainer();
}

}

static class EntryOrdering implements Comparator<Entry> {

private final ConcurrentMap<UniqueId, Integer> attempts = new ConcurrentHashMap<>();

@Override
public int compare(Entry a, Entry b) {
var result = a.compareTo(b);
if (result == 0) {
result = Integer.compare(attempts(b), attempts(a));
}
return result;
}

void incrementAttempts(Entry entry) {
attempts.compute(key(entry), (key, n) -> {
if (n == null) {
registerForKeyRemoval(entry, key);
return 1;
}
return n + 1;
});
}

private int attempts(Entry entry) {
return attempts.getOrDefault(key(entry), 0);
}

@SuppressWarnings("FutureReturnValueIgnored")
private void registerForKeyRemoval(Entry entry, UniqueId key) {
entry.future.whenComplete((__, ___) -> attempts.remove(key));
}

private static UniqueId key(Entry entry) {
return entry.task.getTestDescriptor().getUniqueId();
}
}

}

static class WorkerLeaseManager {
Expand Down Expand Up @@ -699,6 +664,14 @@ void reacquire() throws InterruptedException {
LOGGER.trace(() -> "reacquired worker lease (available: %d)".formatted(semaphore.availablePermits()));
}
}

@Override
public String toString() {
return new ToStringBuilder(this) //
.append("parallelism", parallelism) //
.append("semaphore", semaphore) //
.toString();
}
}

static class WorkerLease implements AutoCloseable {
Expand Down Expand Up @@ -735,4 +708,19 @@ void reacquire() throws InterruptedException {
reacquisitionToken = null;
}
}

private record LeaseAwareRejectedExecutionHandler(WorkerLeaseManager workerLeaseManager)
implements RejectedExecutionHandler {
@Override
public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
if (!(r instanceof RunLeaseAwareWorker worker)) {
return;
}
worker.workerLease.release(false);
if (executor.isShutdown() || workerLeaseManager.isAtLeastOneLeaseTaken()) {
return;
}
throw new RejectedExecutionException("Task with " + workerLeaseManager + " rejected from " + executor);
}
}
}
Loading