Skip to content

Commit

Permalink
Default to VirtualThreads when Executor not provided (#813)
Browse files Browse the repository at this point in the history
Make executor and parallelism parameters optional and default to virtual
threads when not provided.
  • Loading branch information
pivovarit authored Jan 20, 2024
1 parent bb8ccd2 commit 5d9450c
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 18 deletions.
14 changes: 14 additions & 0 deletions src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
Expand Down Expand Up @@ -83,6 +84,12 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
return combined;
}

static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper) {
requireNonNull(mapper, "mapper can't be null");

return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), Function.identity());
}

static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
Expand All @@ -93,6 +100,13 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity());
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper) {
requireNonNull(collector, "collector can't be null");
requireNonNull(mapper, "mapper can't be null");

return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(),s -> s.collect(collector));
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(collector, "collector can't be null");
requireNonNull(executor, "executor can't be null");
Expand Down
14 changes: 9 additions & 5 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ final class Dispatcher<T> {
private final Executor executor;
private final Semaphore limiter;

private Dispatcher(int permits) {
private Dispatcher() {
this.executor = Executors.newVirtualThreadPerTaskExecutor();
this.limiter = new Semaphore(permits);
this.limiter = null;
}

private Dispatcher(Executor executor, int permits) {
Expand All @@ -31,8 +31,8 @@ static <T> Dispatcher<T> from(Executor executor, int permits) {
return new Dispatcher<>(executor, permits);
}

static <T> Dispatcher<T> virtual(int permits) {
return new Dispatcher<>(permits);
static <T> Dispatcher<T> virtual() {
return new Dispatcher<>();
}

CompletableFuture<T> enqueue(Supplier<T> supplier) {
Expand All @@ -51,7 +51,11 @@ private FutureTask<T> completionTask(Supplier<T> supplier, InterruptibleCompleta
FutureTask<T> task = new FutureTask<>(() -> {
if (!completionSignaller.isCompletedExceptionally()) {
try {
withLimiter(supplier, future);
if (limiter == null) {
future.complete(supplier.get());
} else {
withLimiter(supplier, future);
}
} catch (Throwable e) {
completionSignaller.completeExceptionally(e);
}
Expand Down
104 changes: 104 additions & 0 deletions src/main/java/com/pivovarit/collectors/ParallelCollectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@ public final class ParallelCollectors {
private ParallelCollectors() {
}


/**
* A convenience {@link Collector} used for executing parallel computations using Virtual Threads
* and returning them as a {@link CompletableFuture} containing a result of the application of the user-provided {@link Collector}.
*
* <br>
* Example:
* <pre>{@code
* CompletableFuture<List<String>> result = Stream.of(1, 2, 3)
* .collect(parallel(i -> foo(i), toList()));
* }</pre>
*
* @param mapper a transformation to be performed in parallel
* @param collector the {@code Collector} describing the reduction
* @param <T> the type of the collected elements
* @param <R> the result returned by {@code mapper}
* @param <RR> the reduction result {@code collector}
*
* @return a {@code Collector} which collects all processed elements into a user-provided mutable {@code Collection} in parallel
*
* @since 3.0.0
*/
public static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> parallel(Function<T, R> mapper, Collector<R, ?, RR> collector) {
return AsyncParallelCollector.collectingWithCollector(collector, mapper);
}

/**
* A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
* and returning them as a {@link CompletableFuture} containing a result of the application of the user-provided {@link Collector}.
Expand Down Expand Up @@ -44,6 +70,32 @@ private ParallelCollectors() {
return AsyncParallelCollector.collectingWithCollector(collector, mapper, executor, parallelism);
}

/**
* A convenience {@link Collector} used for executing parallel computations using Virtual Threads
* and returning them as {@link CompletableFuture} containing a {@link Stream} of these elements.
*
* <br><br>
* The collector maintains the order of processed {@link Stream}. Instances should not be reused.
*
* <br>
* Example:
* <pre>{@code
* CompletableFuture<Stream<String>> result = Stream.of(1, 2, 3)
* .collect(parallel(i -> foo()));
* }</pre>
*
* @param mapper a transformation to be performed in parallel
* @param <T> the type of the collected elements
* @param <R> the result returned by {@code mapper}
*
* @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel
*
* @since 3.0.0
*/
public static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> parallel(Function<T, R> mapper) {
return AsyncParallelCollector.collectingToStream(mapper);
}

/**
* A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
* and returning them as {@link CompletableFuture} containing a {@link Stream} of these elements.
Expand Down Expand Up @@ -72,6 +124,32 @@ private ParallelCollectors() {
return AsyncParallelCollector.collectingToStream(mapper, executor, parallelism);
}

/**
* A convenience {@link Collector} used for executing parallel computations using Virtual Threads
* and returning a {@link Stream} instance returning results as they arrive.
* <p>
* For the parallelism of 1, the stream is executed by the calling thread.
*
* <br>
* Example:
* <pre>{@code
* Stream.of(1, 2, 3)
* .collect(parallelToStream(i -> foo()))
* .forEach(System.out::println);
* }</pre>
*
* @param mapper a transformation to be performed in parallel
* @param <T> the type of the collected elements
* @param <R> the result returned by {@code mapper}
*
* @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel
*
* @since 3.0.0
*/
public static <T, R> Collector<T, ?, Stream<R>> parallelToStream(Function<T, R> mapper) {
return ParallelStreamCollector.streaming(mapper);
}

/**
* A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
* and returning a {@link Stream} instance returning results as they arrive.
Expand Down Expand Up @@ -100,6 +178,32 @@ private ParallelCollectors() {
return ParallelStreamCollector.streaming(mapper, executor, parallelism);
}

/**
* A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
* and returning a {@link Stream} instance returning results as they arrive while maintaining the initial order.
* <p>
* For the parallelism of 1, the stream is executed by the calling thread.
*
* <br>
* Example:
* <pre>{@code
* Stream.of(1, 2, 3)
* .collect(parallelToOrderedStream(i -> foo()))
* .forEach(System.out::println);
* }</pre>
*
* @param mapper a transformation to be performed in parallel
* @param <T> the type of the collected elements
* @param <R> the result returned by {@code mapper}
*
* @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel
*
* @since 3.0.0
*/
public static <T, R> Collector<T, ?, Stream<R>> parallelToOrderedStream(Function<T, R> mapper) {
return ParallelStreamCollector.streamingOrdered(mapper);
}

/**
* A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
* and returning a {@link Stream} instance returning results as they arrive while maintaining the initial order.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ public Set<Characteristics> characteristics() {
return characteristics;
}

static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper) {
requireNonNull(mapper, "mapper can't be null");

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.virtual());
}

static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
Expand All @@ -86,6 +92,12 @@ public Set<Characteristics> characteristics() {
return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper) {
requireNonNull(mapper, "mapper can't be null");

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.virtual());
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor,
int parallelism) {
requireNonNull(executor, "executor can't be null");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.pivovarit.collectors;

import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
Expand All @@ -11,10 +10,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static java.time.Duration.ofMillis;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down
54 changes: 52 additions & 2 deletions src/test/java/com/pivovarit/collectors/FunctionalTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import java.time.Duration;
import java.time.LocalTime;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
Expand Down Expand Up @@ -64,6 +63,12 @@ class FunctionalTest {
@TestFactory
Stream<DynamicTest> collectors() {
return of(
// virtual threads
virtualThreadsTests((m, e, p) -> parallel(m, toList()), "ParallelCollectors.parallel(toList()) [virtual]", true),
virtualThreadsTests((m, e, p) -> parallel(m, toSet()), "ParallelCollectors.parallel(toSet()) [virtual]", false),
virtualThreadsTests((m, e, p) -> parallel(m, toCollection(LinkedList::new)), "ParallelCollectors.parallel(toCollection()) [virtual]", true),
virtualThreadsTests((m, e, p) -> adapt(parallel(m)), "ParallelCollectors.parallel() [virtual]", true),
// platform threads
tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM), false),
tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM), true),
Expand All @@ -84,6 +89,10 @@ Stream<DynamicTest> batching_collectors() {
@TestFactory
Stream<DynamicTest> streaming_collectors() {
return of(
// virtual threads
virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToStream(m)), "ParallelCollectors.parallelToStream() [virtual]", false),
virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m)), "ParallelCollectors.parallelToOrderedStream() [virtual]", true),
// platform threads
streamingTests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM), false),
streamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM), true)
).flatMap(i -> i);
Expand All @@ -92,6 +101,10 @@ Stream<DynamicTest> streaming_collectors() {
@TestFactory
Stream<DynamicTest> streaming_batching_collectors() {
return of(
// virtual threads
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), "ParallelCollectors.Batching.parallelToStream() [virtual]", false),
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), "ParallelCollectors.Batching.parallelToOrderedStream(p=%d) [virtual]", true),
// platform threads
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToStream(m, e, p)), format("ParallelCollectors.Batching.parallelToStream(p=%d)", PARALLELISM), false),
batchStreamingTests((m, e, p) -> adaptAsync(Batching.parallelToOrderedStream(m, e, p)), format("ParallelCollectors.Batching.parallelToOrderedStream(p=%d)", PARALLELISM), true)
).flatMap(i -> i);
Expand Down Expand Up @@ -150,6 +163,26 @@ void shouldExecuteEagerlyOnProvidedThreadPool() {
}
}

private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThreadsTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
var tests = of(
shouldCollect(collector, name, 1),
shouldCollect(collector, name, PARALLELISM),
shouldCollectNElementsWithNParallelism(collector, name, 1),
shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM),
shouldCollectToEmpty(collector, name),
shouldStartConsumingImmediately(collector, name),
shouldNotBlockTheCallingThread(collector, name),
shouldHandleThrowable(collector, name),
shouldShortCircuitOnException(collector, name),
shouldInterruptOnException(collector, name),
shouldRemainConsistent(collector, name)
);

return maintainsOrder
? Stream.concat(tests, of(shouldMaintainOrder(collector, name)))
: tests;
}

private static <R extends Collection<Integer>> Stream<DynamicTest> tests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
var tests = of(
shouldCollect(collector, name, 1),
Expand All @@ -174,6 +207,23 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> tests(Collect
: tests;
}

private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThreadsStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
var tests = of(
shouldCollect(collector, name, 1),
shouldCollect(collector, name, PARALLELISM),
shouldCollectToEmpty(collector, name),
shouldStartConsumingImmediately(collector, name),
shouldNotBlockTheCallingThread(collector, name),
shouldHandleThrowable(collector, name),
shouldShortCircuitOnException(collector, name),
shouldRemainConsistent(collector, name)
);

return maintainsOrder
? Stream.concat(tests, of(shouldMaintainOrder(collector, name)))
: tests;
}

private static <R extends Collection<Integer>> Stream<DynamicTest> streamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean maintainsOrder) {
var tests = of(
shouldCollect(collector, name, 1),
Expand Down Expand Up @@ -306,7 +356,7 @@ private static <R extends Collection<Integer>> DynamicTest shouldShortCircuitOnE
int size = 4;

runWithExecutor(e -> {
LongAdder counter = new LongAdder();
AtomicInteger counter = new AtomicInteger();

assertThatThrownBy(elements.stream()
.collect(collector.apply(i -> incrementAndThrow(counter), e, PARALLELISM))::join)
Expand Down
14 changes: 6 additions & 8 deletions src/test/java/com/pivovarit/collectors/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public final class TestUtils {
Expand Down Expand Up @@ -36,14 +36,12 @@ public static <T> T returnWithDelay(T value, Duration duration) {
return value;
}

public static Integer incrementAndThrow(LongAdder counter) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
// ignore purposefully
public static Integer incrementAndThrow(AtomicInteger counter) {
if (counter.incrementAndGet() == 10) {
throw new IllegalArgumentException();
}
counter.increment();
throw new IllegalArgumentException();

return counter.intValue();
}

public static void runWithExecutor(Consumer<Executor> consumer, int size) {
Expand Down

0 comments on commit 5d9450c

Please sign in to comment.