From c48c915beeeee8ccd48137b20313668ecf0aa192 Mon Sep 17 00:00:00 2001 From: Grzegorz Piwowarek Date: Thu, 12 Oct 2023 08:48:20 +0200 Subject: [PATCH] Dispatcher to use caller thread instead of dedicated scheduler thread (#789) Remove the internal single-thread scheduler and rely on the caller thread to submit all relevant tasks to a given thread pool. This not only simplified the solution, but also: - helped avoid context propagation issues when execution switches between multiple threads - made the tool more Loom-friendly since instances of `ParallelCollectors` do not create their own threads --- README.md | 2 +- .../collectors/AsyncParallelCollector.java | 13 +- .../com/pivovarit/collectors/Dispatcher.java | 130 +++-------- .../collectors/ParallelStreamCollector.java | 10 +- .../CompletionOrderSpliteratorTest.java | 16 +- .../pivovarit/collectors/FunctionalTest.java | 207 +++++++++--------- .../collectors/FutureCollectorsTest.java | 49 +++-- .../com/pivovarit/collectors/TestUtils.java | 15 ++ .../pivovarit/collectors/benchmark/Bench.java | 6 +- 9 files changed, 189 insertions(+), 259 deletions(-) diff --git a/README.md b/README.md index 8cd2b858..8f8ca2ab 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ For example: List result = list.parallelStream() .map(i -> foo(i)) // runs implicitly on ForkJoinPool.commonPool() - .collect(Collectors.toList()); + .toList(); In order to avoid such problems, **the solution is to isolate blocking tasks** and run them on a separate thread pool... but there's a catch. diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index 7411ab6d..e2d4b48a 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -55,21 +55,12 @@ public BinaryOperator>> combiner() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> { - if (!dispatcher.isRunning()) { - dispatcher.start(); - } - acc.add(dispatcher.enqueue(() -> mapper.apply(e))); - }; + return (acc, e) -> acc.add(dispatcher.enqueue(() -> mapper.apply(e))); } @Override public Function>, CompletableFuture> finisher() { - return futures -> { - dispatcher.stop(); - - return combine(futures).thenApply(processor); - }; + return futures -> combine(futures).thenApply(processor); } @Override diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index 87bfba5c..23a9aad6 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -1,18 +1,11 @@ package com.pivovarit.collectors; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.FutureTask; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; +import java.util.function.BiConsumer; import java.util.function.Supplier; /** @@ -20,22 +13,12 @@ */ final class Dispatcher { - private static final Runnable POISON_PILL = () -> System.out.println("Why so serious?"); - private final CompletableFuture completionSignaller = new CompletableFuture<>(); - - private final BlockingQueue workingQueue = new LinkedBlockingQueue<>(); - - private final ExecutorService dispatcher = newLazySingleThreadExecutor(); private final Executor executor; private final Semaphore limiter; - private final AtomicBoolean started = new AtomicBoolean(false); - - private volatile boolean shortCircuited = false; - private Dispatcher(int permits) { - this.executor = defaultExecutorService(); + this.executor = Executors.newVirtualThreadPerTaskExecutor(); this.limiter = new Semaphore(permits); } @@ -52,110 +35,65 @@ static Dispatcher virtual(int permits) { return new Dispatcher<>(permits); } - void start() { - if (!started.getAndSet(true)) { - dispatcher.execute(() -> { - try { - while (true) { - Runnable task; - if ((task = workingQueue.take()) != POISON_PILL) { - executor.execute(() -> { - try { - limiter.acquire(); - task.run(); - } catch (InterruptedException e) { - handle(e); - } finally { - limiter.release(); - } - }); - } else { - break; - } - } - } catch (Throwable e) { - handle(e); - } - }); - } - } - - void stop() { + CompletableFuture enqueue(Supplier supplier) { + InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); + completionSignaller.whenComplete(shortcircuit(future)); try { - workingQueue.put(POISON_PILL); - } catch (InterruptedException e) { + executor.execute(completionTask(supplier, future)); + } catch (Throwable e) { completionSignaller.completeExceptionally(e); - } finally { - dispatcher.shutdown(); + return CompletableFuture.failedFuture(e); } - } - - boolean isRunning() { - return started.get(); - } - - CompletableFuture enqueue(Supplier supplier) { - InterruptibleCompletableFuture future = new InterruptibleCompletableFuture<>(); - workingQueue.add(completionTask(supplier, future)); - completionSignaller.exceptionally(shortcircuit(future)); return future; } - private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { - FutureTask task = new FutureTask<>(() -> { - try { - if (!shortCircuited) { - future.complete(supplier.get()); + private FutureTask completionTask(Supplier supplier, InterruptibleCompletableFuture future) { + FutureTask task = new FutureTask<>(() -> { + if (!completionSignaller.isCompletedExceptionally()) { + try { + withLimiter(supplier, future); + } catch (Throwable e) { + completionSignaller.completeExceptionally(e); } - } catch (Throwable e) { - handle(e); } }, null); future.completedBy(task); return task; } - private void handle(Throwable e) { - shortCircuited = true; - completionSignaller.completeExceptionally(e); - dispatcher.shutdownNow(); + private void withLimiter(Supplier supplier, InterruptibleCompletableFuture future) throws InterruptedException { + try { + limiter.acquire(); + future.complete(supplier.get()); + } finally { + limiter.release(); + } } - private static Function shortcircuit(InterruptibleCompletableFuture future) { - return throwable -> { - future.completeExceptionally(throwable); - future.cancel(true); - return null; + private static BiConsumer shortcircuit(InterruptibleCompletableFuture future) { + return (__, throwable) -> { + if (throwable != null) { + future.completeExceptionally(throwable); + future.cancel(true); + } }; } - private static ThreadPoolExecutor newLazySingleThreadExecutor() { - return new ThreadPoolExecutor(1, 1, - 0L, TimeUnit.MILLISECONDS, - new SynchronousQueue<>(), // dispatcher always executes a single task - Thread.ofPlatform() - .name("parallel-collectors-dispatcher-", 0) - .daemon(false) - .factory()); - } - static final class InterruptibleCompletableFuture extends CompletableFuture { - private volatile FutureTask backingTask; - private void completedBy(FutureTask task) { + private volatile FutureTask backingTask; + + private void completedBy(FutureTask task) { backingTask = task; } @Override public boolean cancel(boolean mayInterruptIfRunning) { - if (backingTask != null) { - backingTask.cancel(mayInterruptIfRunning); + var task = backingTask; + if (task != null) { + task.cancel(mayInterruptIfRunning); } return super.cancel(mayInterruptIfRunning); } - - } - private static ExecutorService defaultExecutorService() { - return Executors.newVirtualThreadPerTaskExecutor(); } } diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 24d99cab..4902946d 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -57,10 +57,7 @@ public Supplier>> supplier() { @Override public BiConsumer>, T> accumulator() { - return (acc, e) -> { - dispatcher.start(); - acc.add(dispatcher.enqueue(() -> function.apply(e))); - }; + return (acc, e) -> acc.add(dispatcher.enqueue(() -> function.apply(e))); } @Override @@ -73,10 +70,7 @@ public BinaryOperator>> combiner() { @Override public Function>, Stream> finisher() { - return acc -> { - dispatcher.stop(); - return completionStrategy.apply(acc); - }; + return completionStrategy; } @Override diff --git a/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java b/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java index 2fecda94..7a04dd3a 100644 --- a/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java +++ b/src/test/java/com/pivovarit/collectors/CompletionOrderSpliteratorTest.java @@ -1,5 +1,6 @@ package com.pivovarit.collectors; +import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -35,9 +36,7 @@ void shouldTraverseInCompletionOrder() { sleep(100); f2.complete(1); }); - List results = StreamSupport.stream( - new CompletionOrderSpliterator<>(futures), false) - .collect(Collectors.toList()); + var results = StreamSupport.stream(new CompletionOrderSpliterator<>(futures), false).toList(); assertThat(results).containsExactly(3, 2, 1); } @@ -56,9 +55,7 @@ void shouldPropagateException() { sleep(100); f2.complete(1); }); - assertThatThrownBy(() -> StreamSupport.stream( - new CompletionOrderSpliterator<>(futures), false) - .collect(Collectors.toList())) + assertThatThrownBy(() -> StreamSupport.stream(new CompletionOrderSpliterator<>(futures), false).toList()) .isInstanceOf(CompletionException.class) .hasCauseExactlyInstanceOf(RuntimeException.class); } @@ -96,26 +93,23 @@ void shouldNotConsumeOnEmpty() { } @Test - void shouldRestoreInterrupt() throws InterruptedException { + void shouldRestoreInterrupt() { Thread executorThread = new Thread(() -> { Spliterator spliterator = new CompletionOrderSpliterator<>(Arrays.asList(new CompletableFuture<>())); try { spliterator.tryAdvance(i -> {}); } catch (Exception e) { while (true) { - + Thread.onSpinWait(); } } }); executorThread.start(); - Thread.sleep(100); - executorThread.interrupt(); await() - .pollDelay(ofMillis(100)) .until(executorThread::isInterrupted); } diff --git a/src/test/java/com/pivovarit/collectors/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/FunctionalTest.java index 3172a5a2..0e2b1164 100644 --- a/src/test/java/com/pivovarit/collectors/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/FunctionalTest.java @@ -5,7 +5,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestFactory; -import java.lang.reflect.Field; import java.time.Duration; import java.time.LocalTime; import java.util.Arrays; @@ -30,7 +29,6 @@ import java.util.concurrent.atomic.LongAdder; import java.util.function.Function; import java.util.stream.Collector; -import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -40,6 +38,7 @@ import static com.pivovarit.collectors.TestUtils.incrementAndThrow; import static com.pivovarit.collectors.TestUtils.returnWithDelay; import static com.pivovarit.collectors.TestUtils.runWithExecutor; +import static com.pivovarit.collectors.TestUtils.withExecutor; import static java.lang.String.format; import static java.time.Duration.ofMillis; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -61,7 +60,6 @@ class FunctionalTest { private static final int PARALLELISM = 1000; - private static Executor executor = Executors.newFixedThreadPool(100); @TestFactory Stream collectors() { @@ -102,23 +100,23 @@ Stream streaming_batching_collectors() { @Test void shouldCollectInCompletionOrder() { // given - executor = threadPoolExecutor(4); - - List result = Stream.of(350, 200, 0, 400) - .collect(parallelToStream(i -> returnWithDelay(i, ofMillis(i)), executor, 4)) - .limit(2) - .collect(toList()); + try (var executor = threadPoolExecutor(4)) { + List result = of(350, 200, 0, 400) + .collect(parallelToStream(i -> returnWithDelay(i, ofMillis(i)), executor, 4)) + .limit(2) + .toList(); - assertThat(result).isSorted(); + assertThat(result).isSorted(); + } } @Test void shouldCollectEagerlyInCompletionOrder() { // given - executor = threadPoolExecutor(4); + var executor = threadPoolExecutor(4); AtomicBoolean result = new AtomicBoolean(false); CompletableFuture.runAsync(() -> { - Stream.of(1, 10000, 1, 0) + of(1, 10000, 1, 0) .collect(parallelToStream(i -> returnWithDelay(i, ofMillis(i)), executor, 2)) .forEach(i -> { if (i == 1) { @@ -134,35 +132,32 @@ void shouldCollectEagerlyInCompletionOrder() { @Test void shouldExecuteEagerlyOnProvidedThreadPool() { - ExecutorService executor = Executors.newFixedThreadPool(2); - CountingExecutor countingExecutor = new CountingExecutor(executor); - AtomicInteger executions = new AtomicInteger(); - try { - List list = Arrays.asList("A", "B"); + try (var executor = Executors.newFixedThreadPool(2)) { + var countingExecutor = new CountingExecutor(executor); + var executions = new AtomicInteger(); + var list = List.of("A", "B"); - Stream stream = list.stream() + list.stream() .collect(parallel(s -> { executions.incrementAndGet(); return s; }, countingExecutor, 1)) - .join(); - } finally { - executor.shutdown(); - } + .join() + .forEach(__ -> {}); - assertThat(countingExecutor.getInvocations()).isEqualTo(1); - assertThat(executions.get()).isEqualTo(2); + assertThat(countingExecutor.getInvocations()).isEqualTo(1); + assertThat(executions.get()).isEqualTo(2); + } } private static > Stream tests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { - Stream tests = of( + var tests = of( shouldCollect(collector, name, 1), shouldCollect(collector, name, PARALLELISM), shouldCollectNElementsWithNParallelism(collector, name, 1), shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM), shouldCollectToEmpty(collector, name), shouldStartConsumingImmediately(collector, name), - shouldTerminateAfterConsumingAllElements(collector, name), shouldNotBlockTheCallingThread(collector, name), shouldRespectParallelism(collector, name), shouldHandleThrowable(collector, name), @@ -170,32 +165,33 @@ private static > Stream tests(Collect shouldInterruptOnException(collector, name), shouldHandleRejectedExecutionException(collector, name), shouldRemainConsistent(collector, name), - shouldRejectInvalidParallelism(collector, name) + shouldRejectInvalidParallelism(collector, name), + shouldHandleExecutorRejection(collector, name) ); return maintainsOrder - ? Stream.concat(tests, Stream.of(shouldMaintainOrder(collector, name))) + ? Stream.concat(tests, of(shouldMaintainOrder(collector, name))) : tests; } private static > Stream streamingTests(CollectorSupplier, Executor, Integer, Collector>> collector, String name, boolean maintainsOrder) { - Stream tests = of( + var tests = of( shouldCollect(collector, name, 1), shouldCollect(collector, name, PARALLELISM), shouldCollectToEmpty(collector, name), shouldStartConsumingImmediately(collector, name), - shouldTerminateAfterConsumingAllElements(collector, name), shouldNotBlockTheCallingThread(collector, name), shouldRespectParallelism(collector, name), shouldHandleThrowable(collector, name), shouldShortCircuitOnException(collector, name), shouldHandleRejectedExecutionException(collector, name), shouldRemainConsistent(collector, name), - shouldRejectInvalidParallelism(collector, name) + shouldRejectInvalidParallelism(collector, name), + shouldHandleExecutorRejection(collector, name) ); return maintainsOrder - ? Stream.concat(tests, Stream.of(shouldMaintainOrder(collector, name))) + ? Stream.concat(tests, of(shouldMaintainOrder(collector, name))) : tests; } @@ -213,16 +209,19 @@ private static > Stream batchStreamin private static > DynamicTest shouldNotBlockTheCallingThread(CollectorSupplier, Executor, Integer, Collector>> c, String name) { return dynamicTest(format("%s: should not block when returning future", name), () -> { - assertTimeoutPreemptively(ofMillis(100), () -> - Stream.empty().collect(c - .apply(i -> returnWithDelay(42, ofMillis(Integer.MAX_VALUE)), executor, 1)), "returned blocking future"); + withExecutor(e -> { + assertTimeoutPreemptively(ofMillis(100), () -> + Stream.empty().collect(c + .apply(i -> returnWithDelay(42, ofMillis(Integer.MAX_VALUE)), e, 1)), "returned blocking future"); + }); }); } private static > DynamicTest shouldCollectToEmpty(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should collect to empty", name), () -> { - assertThat(Stream.empty().collect(collector.apply(i -> i, executor, PARALLELISM)).join()) - .isEmpty(); + withExecutor(e -> { + assertThat(Stream.empty().collect(collector.apply(i -> i, e, PARALLELISM)).join()).isEmpty(); + }); }); } @@ -230,83 +229,64 @@ private static > DynamicTest shouldRespectParallel return dynamicTest(format("%s: should respect parallelism", name), () -> { int parallelism = 2; int delayMillis = 50; - executor = Executors.newCachedThreadPool(); - - LocalTime before = LocalTime.now(); - Stream.generate(() -> 42) - .limit(4) - .collect(collector.apply(i -> returnWithDelay(i, ofMillis(delayMillis)), executor, parallelism)) - .join(); + withExecutor(e -> { + LocalTime before = LocalTime.now(); + Stream.generate(() -> 42) + .limit(4) + .collect(collector.apply(i -> returnWithDelay(i, ofMillis(delayMillis)), e, parallelism)) + .join(); - LocalTime after = LocalTime.now(); - assertThat(Duration.between(before, after)) - .isGreaterThanOrEqualTo(Duration.ofMillis(delayMillis * parallelism)); + LocalTime after = LocalTime.now(); + assertThat(Duration.between(before, after)) + .isGreaterThanOrEqualTo(ofMillis(delayMillis * parallelism)); + }); }); } private static > DynamicTest shouldProcessOnNThreadsETParallelism(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should batch", name), () -> { int parallelism = 2; - executor = Executors.newFixedThreadPool(10); - Set threads = new ConcurrentSkipListSet<>(); + withExecutor(e -> { + Set threads = new ConcurrentSkipListSet<>(); - Stream.generate(() -> 42) - .limit(100) - .collect(collector.apply(i -> { - threads.add(Thread.currentThread().getName()); - return i; - }, executor, parallelism)) - .join(); + Stream.generate(() -> 42) + .limit(100) + .collect(collector.apply(i -> { + threads.add(Thread.currentThread().getName()); + return i; + }, e, parallelism)) + .join(); - assertThat(threads).hasSize(parallelism); + assertThat(threads).hasSize(parallelism); + }); }); } private static > DynamicTest shouldCollect(CollectorSupplier, Executor, Integer, Collector>> factory, String name, int parallelism) { return dynamicTest(format("%s: should collect with parallelism %s", name, parallelism), () -> { - List elements = IntStream.range(0, 10).boxed().collect(toList()); - Collector> ctor = factory.apply(i -> i, executor, parallelism); - Collection result = elements.stream().collect(ctor) - .join(); + var elements = IntStream.range(0, 10).boxed().toList(); - assertThat(result).hasSameElementsAs(elements); + withExecutor(e -> { + Collector> ctor = factory.apply(i -> i, e, parallelism); + Collection result = elements.stream().collect(ctor) + .join(); + + assertThat(result).hasSameElementsAs(elements); + }); }); } private static > DynamicTest shouldCollectNElementsWithNParallelism(CollectorSupplier, Executor, Integer, Collector>> factory, String name, int parallelism) { return dynamicTest(format("%s: should collect %s elements with parallelism %s", name, parallelism, parallelism), () -> { + var elements = IntStream.iterate(0, i -> i + 1).limit(parallelism).boxed().toList(); - List elements = IntStream.iterate(0, i -> i + 1).limit(parallelism).boxed().collect(toList()); - Collector> ctor = factory.apply(i -> i, executor, parallelism); - Collection result = elements.stream().collect(ctor) - .join(); - - assertThat(result).hasSameElementsAs(elements); - }); - } - - private static > DynamicTest shouldTerminateAfterConsumingAllElements(CollectorSupplier, Executor, Integer, Collector>> factory, String name) { - return dynamicTest(format("%s: should terminate after consuming all elements", name), () -> { - List elements = IntStream.range(0, 10).boxed().collect(toList()); - Collector> ctor = factory.apply(i -> i, executor, 10); - Collection result = elements.stream().collect(ctor) - .join(); - - assertThat(result).hasSameElementsAs(elements); + withExecutor(e -> { + Collector> ctor = factory.apply(i -> i, e, parallelism); + Collection result = elements.stream().collect(ctor).join(); - if (ctor instanceof AsyncParallelCollector) { - Field dispatcherField = AsyncParallelCollector.class.getDeclaredField("dispatcher"); - dispatcherField.setAccessible(true); - Dispatcher dispatcher = (Dispatcher) dispatcherField.get(ctor); - Field innerDispatcherField = Dispatcher.class.getDeclaredField("dispatcher"); - innerDispatcherField.setAccessible(true); - ExecutorService executor = (ExecutorService) innerDispatcherField.get(dispatcher); - - await() - .atMost(Duration.ofSeconds(2)) - .until(executor::isTerminated); - } + assertThat(result).hasSameElementsAs(elements); + }); }); } @@ -359,7 +339,7 @@ private static > DynamicTest shouldHandleThrowable private static > DynamicTest shouldHandleRejectedExecutionException(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should propagate rejected execution exception", name), () -> { - Executor executor = command -> { throw new RejectedExecutionException(); }; + Executor executor = command -> {throw new RejectedExecutionException();}; List elements = IntStream.range(0, 1000).boxed().toList(); assertThatThrownBy(() -> elements.stream() @@ -381,7 +361,7 @@ private static > DynamicTest shouldRemainConsisten ExecutorService executor = Executors.newFixedThreadPool(parallelism); try { - List elements = IntStream.range(0, parallelism).boxed().collect(toList()); + List elements = IntStream.range(0, parallelism).boxed().toList(); CountDownLatch countDownLatch = new CountDownLatch(parallelism); @@ -408,23 +388,40 @@ private static > DynamicTest shouldRemainConsisten private static > DynamicTest shouldRejectInvalidParallelism(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should reject invalid parallelism", name), () -> { - assertThatThrownBy(() -> collector.apply(i -> i, executor, -1)) - .isExactlyInstanceOf(IllegalArgumentException.class); + withExecutor(e -> { + assertThatThrownBy(() -> collector.apply(i -> i, e, -1)) + .isExactlyInstanceOf(IllegalArgumentException.class); + }); + }); + } + + private static > DynamicTest shouldHandleExecutorRejection(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { + return dynamicTest(format("%s: should handle rejected execution", name), () -> { + assertThatThrownBy(() -> { + try (var e = new ThreadPoolExecutor(2, 2, 0L, MILLISECONDS, + new LinkedBlockingQueue<>(1), new ThreadPoolExecutor.AbortPolicy())) { + assertTimeoutPreemptively(ofMillis(100), () -> of(1, 2, 3, 4) + .collect(collector.apply(i -> TestUtils.sleepAndReturn(1_000, i), e, Integer.MAX_VALUE)) + .join()); + } + }).isExactlyInstanceOf(CompletionException.class); }); } private static > DynamicTest shouldStartConsumingImmediately(CollectorSupplier, Executor, Integer, Collector>> collector, String name) { return dynamicTest(format("%s: should start consuming immediately", name), () -> { - AtomicInteger counter = new AtomicInteger(); + try (var e = Executors.newCachedThreadPool()) { + var counter = new AtomicInteger(); - Stream.iterate(0, i -> returnWithDelay(i + 1, ofMillis(100))) - .limit(2) - .collect(collector.apply(i -> counter.incrementAndGet(), executor, PARALLELISM)); + Stream.iterate(0, i -> returnWithDelay(i + 1, ofMillis(100))) + .limit(2) + .collect(collector.apply(i -> counter.incrementAndGet(), e, PARALLELISM)); - await() - .pollInterval(Duration.ofMillis(10)) - .atMost(50, MILLISECONDS) - .until(() -> counter.get() > 0); + await() + .pollInterval(ofMillis(10)) + .atMost(50, MILLISECONDS) + .until(() -> counter.get() > 0); + } }); } @@ -455,12 +452,12 @@ private static > DynamicTest shouldInterruptOnExce } private static Collector>> adapt(Collector>> input) { - return collectingAndThen(input, stream -> stream.thenApply(s -> s.collect(Collectors.toList()))); + return collectingAndThen(input, stream -> stream.thenApply(Stream::toList)); } private static Collector>> adaptAsync(Collector> input) { return collectingAndThen(input, stream -> CompletableFuture - .supplyAsync(() -> stream.collect(toList()), Executors.newSingleThreadExecutor())); + .supplyAsync(stream::toList, Executors.newSingleThreadExecutor())); } private static ThreadPoolExecutor threadPoolExecutor(int unitsOfWork) { diff --git a/src/test/java/com/pivovarit/collectors/FutureCollectorsTest.java b/src/test/java/com/pivovarit/collectors/FutureCollectorsTest.java index bba05a42..b758f597 100644 --- a/src/test/java/com/pivovarit/collectors/FutureCollectorsTest.java +++ b/src/test/java/com/pivovarit/collectors/FutureCollectorsTest.java @@ -30,9 +30,9 @@ void shouldCollect() { @Test void shouldCollectToList() { - List list = Arrays.asList(1, 2, 3); + var list = Arrays.asList(1, 2, 3); - CompletableFuture> result = list.stream() + var result = list.stream() .map(i -> CompletableFuture.supplyAsync(() -> i)) .collect(ParallelCollectors.toFuture(toList())); @@ -41,30 +41,31 @@ void shouldCollectToList() { @Test void shouldShortcircuit() { - List list = IntStream.range(0, 10).boxed().collect(toList()); - - ExecutorService e = Executors.newFixedThreadPool(10); + var list = IntStream.range(0, 10).boxed().toList(); - CompletableFuture> result = list.stream() - .map(i -> CompletableFuture.supplyAsync(() -> { - if (i != 9) { - try { - Thread.sleep(1000); - } catch (InterruptedException ex) { - ex.printStackTrace(); + try (var e = Executors.newFixedThreadPool(10)) { + CompletableFuture> result + = list.stream() + .map(i -> CompletableFuture.supplyAsync(() -> { + if (i != 9) { + try { + Thread.sleep(1000); + } catch (InterruptedException ex) { + ex.printStackTrace(); + } + return i; + } else { + throw new RuntimeException(); } - return i; - } else { - throw new RuntimeException(); - } - }, e)) - .collect(ParallelCollectors.toFuture(toList())); + }, e)) + .collect(ParallelCollectors.toFuture(toList())); - assertTimeout(Duration.ofMillis(100), () -> { - try { - result.join(); - } catch (CompletionException ex) { - } - }); + assertTimeout(Duration.ofMillis(100), () -> { + try { + result.join(); + } catch (CompletionException ex) { + } + }); + } } } diff --git a/src/test/java/com/pivovarit/collectors/TestUtils.java b/src/test/java/com/pivovarit/collectors/TestUtils.java index 4e7179bf..0105fa1b 100644 --- a/src/test/java/com/pivovarit/collectors/TestUtils.java +++ b/src/test/java/com/pivovarit/collectors/TestUtils.java @@ -11,6 +11,21 @@ public final class TestUtils { private TestUtils() { } + public static void withExecutor(Consumer consumer) { + try (var executorService = Executors.newCachedThreadPool()) { + consumer.accept(executorService); + } + } + + public static T sleepAndReturn(int millis, T value) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return value; + } + public static T returnWithDelay(T value, Duration duration) { try { Thread.sleep(duration.toMillis()); diff --git a/src/test/java/com/pivovarit/collectors/benchmark/Bench.java b/src/test/java/com/pivovarit/collectors/benchmark/Bench.java index a642e11d..e4fbd15a 100644 --- a/src/test/java/com/pivovarit/collectors/benchmark/Bench.java +++ b/src/test/java/com/pivovarit/collectors/benchmark/Bench.java @@ -42,7 +42,7 @@ public void tearDown() { private static final List source = IntStream.range(0, 1000) .boxed() - .collect(toList()); + .toList(); @Benchmark public List parallel_collect(BenchmarkState state) { @@ -62,14 +62,14 @@ public List parallel_batch_collect(BenchmarkState state) { public List parallel_streaming(BenchmarkState state) { return source.stream() .collect(ParallelCollectors.parallelToStream(i -> i, state.executor, state.parallelism)) - .collect(toList()); + .toList(); } @Benchmark public List parallel_batch_streaming_collect(BenchmarkState state) { return source.stream() .collect(ParallelCollectors.Batching.parallelToStream(i -> i, state.executor, state.parallelism)) - .collect(toList()); + .toList(); } public static void main(String[] args) throws RunnerException {