Skip to content

Commit

Permalink
Dispatcher to use caller thread instead of dedicated scheduler thread (
Browse files Browse the repository at this point in the history
…#789)

Remove the internal single-thread scheduler and rely on the caller
thread to submit all relevant tasks to a given thread pool. This not
only simplified the solution, but also:
- helped avoid context propagation issues when execution switches
between multiple threads
- made the tool more Loom-friendly since instances of
`ParallelCollectors` do not create their own threads
  • Loading branch information
pivovarit authored Oct 12, 2023
1 parent 2011ed5 commit c48c915
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 259 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ For example:

List<String> result = list.parallelStream()
.map(i -> foo(i)) // runs implicitly on ForkJoinPool.commonPool()
.collect(Collectors.toList());
.toList();

In order to avoid such problems, **the solution is to isolate blocking tasks** and run them on a separate thread pool... but there's a catch.

Expand Down
13 changes: 2 additions & 11 deletions src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,12 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> {
if (!dispatcher.isRunning()) {
dispatcher.start();
}
acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
};
return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e)));
}

@Override
public Function<List<CompletableFuture<R>>, CompletableFuture<C>> finisher() {
return futures -> {
dispatcher.stop();

return combine(futures).thenApply(processor);
};
return futures -> combine(futures).thenApply(processor);
}

@Override
Expand Down
130 changes: 34 additions & 96 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
@@ -1,41 +1,24 @@
package com.pivovarit.collectors;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

/**
* @author Grzegorz Piwowarek
*/
final class Dispatcher<T> {

private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?");

private final CompletableFuture<Void> completionSignaller = new CompletableFuture<>();

private final BlockingQueue<Runnable> workingQueue = new LinkedBlockingQueue<>();

private final ExecutorService dispatcher = newLazySingleThreadExecutor();
private final Executor executor;
private final Semaphore limiter;

private final AtomicBoolean started = new AtomicBoolean(false);

private volatile boolean shortCircuited = false;

private Dispatcher(int permits) {
this.executor = defaultExecutorService();
this.executor = Executors.newVirtualThreadPerTaskExecutor();
this.limiter = new Semaphore(permits);
}

Expand All @@ -52,110 +35,65 @@ static <T> Dispatcher<T> virtual(int permits) {
return new Dispatcher<>(permits);
}

void start() {
if (!started.getAndSet(true)) {
dispatcher.execute(() -> {
try {
while (true) {
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
executor.execute(() -> {
try {
limiter.acquire();
task.run();
} catch (InterruptedException e) {
handle(e);
} finally {
limiter.release();
}
});
} else {
break;
}
}
} catch (Throwable e) {
handle(e);
}
});
}
}

void stop() {
CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
completionSignaller.whenComplete(shortcircuit(future));
try {
workingQueue.put(POISON_PILL);
} catch (InterruptedException e) {
executor.execute(completionTask(supplier, future));
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
} finally {
dispatcher.shutdown();
return CompletableFuture.failedFuture(e);
}
}

boolean isRunning() {
return started.get();
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
InterruptibleCompletableFuture<T> future = new InterruptibleCompletableFuture<>();
workingQueue.add(completionTask(supplier, future));
completionSignaller.exceptionally(shortcircuit(future));
return future;
}

private FutureTask<Void> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<Void> task = new FutureTask<>(() -> {
try {
if (!shortCircuited) {
future.complete(supplier.get());
private FutureTask<T> completionTask(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) {
FutureTask<T> task = new FutureTask<>(() -> {
if (!completionSignaller.isCompletedExceptionally()) {
try {
withLimiter(supplier, future);
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
}
} catch (Throwable e) {
handle(e);
}
}, null);
future.completedBy(task);
return task;
}

private void handle(Throwable e) {
shortCircuited = true;
completionSignaller.completeExceptionally(e);
dispatcher.shutdownNow();
private void withLimiter(Supplier<T> supplier, InterruptibleCompletableFuture<T> future) throws InterruptedException {
try {
limiter.acquire();
future.complete(supplier.get());
} finally {
limiter.release();
}
}

private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFuture<?> future) {
return throwable -> {
future.completeExceptionally(throwable);
future.cancel(true);
return null;
private static <T> BiConsumer<T, Throwable> shortcircuit(InterruptibleCompletableFuture<?> future) {
return (__, throwable) -> {
if (throwable != null) {
future.completeExceptionally(throwable);
future.cancel(true);
}
};
}

private static ThreadPoolExecutor newLazySingleThreadExecutor() {
return new ThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
new SynchronousQueue<>(), // dispatcher always executes a single task
Thread.ofPlatform()
.name("parallel-collectors-dispatcher-", 0)
.daemon(false)
.factory());
}

static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {

private volatile FutureTask<?> backingTask;
private void completedBy(FutureTask<Void> task) {
private volatile FutureTask<T> backingTask;

private void completedBy(FutureTask<T> task) {
backingTask = task;
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
if (backingTask != null) {
backingTask.cancel(mayInterruptIfRunning);
var task = backingTask;
if (task != null) {
task.cancel(mayInterruptIfRunning);
}
return super.cancel(mayInterruptIfRunning);
}

}
private static ExecutorService defaultExecutorService() {
return Executors.newVirtualThreadPerTaskExecutor();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ public Supplier<List<CompletableFuture<R>>> supplier() {

@Override
public BiConsumer<List<CompletableFuture<R>>, T> accumulator() {
return (acc, e) -> {
dispatcher.start();
acc.add(dispatcher.enqueue(() -> function.apply(e)));
};
return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e)));
}

@Override
Expand All @@ -73,10 +70,7 @@ public BinaryOperator<List<CompletableFuture<R>>> combiner() {

@Override
public Function<List<CompletableFuture<R>>, Stream<R>> finisher() {
return acc -> {
dispatcher.stop();
return completionStrategy.apply(acc);
};
return completionStrategy;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.pivovarit.collectors;

import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
Expand Down Expand Up @@ -35,9 +36,7 @@ void shouldTraverseInCompletionOrder() {
sleep(100);
f2.complete(1);
});
List<Integer> results = StreamSupport.stream(
new CompletionOrderSpliterator<>(futures), false)
.collect(Collectors.toList());
var results = StreamSupport.stream(new CompletionOrderSpliterator<>(futures), false).toList();

assertThat(results).containsExactly(3, 2, 1);
}
Expand All @@ -56,9 +55,7 @@ void shouldPropagateException() {
sleep(100);
f2.complete(1);
});
assertThatThrownBy(() -> StreamSupport.stream(
new CompletionOrderSpliterator<>(futures), false)
.collect(Collectors.toList()))
assertThatThrownBy(() -> StreamSupport.stream(new CompletionOrderSpliterator<>(futures), false).toList())
.isInstanceOf(CompletionException.class)
.hasCauseExactlyInstanceOf(RuntimeException.class);
}
Expand Down Expand Up @@ -96,26 +93,23 @@ void shouldNotConsumeOnEmpty() {
}

@Test
void shouldRestoreInterrupt() throws InterruptedException {
void shouldRestoreInterrupt() {
Thread executorThread = new Thread(() -> {
Spliterator<Integer> spliterator = new CompletionOrderSpliterator<>(Arrays.asList(new CompletableFuture<>()));
try {
spliterator.tryAdvance(i -> {});
} catch (Exception e) {
while (true) {

Thread.onSpinWait();
}
}
});

executorThread.start();

Thread.sleep(100);

executorThread.interrupt();

await()
.pollDelay(ofMillis(100))
.until(executorThread::isInterrupted);
}

Expand Down
Loading

0 comments on commit c48c915

Please sign in to comment.