From f2317acf8e488fde02b8be8f0e72986889a3cd44 Mon Sep 17 00:00:00 2001
From: Grzegorz Piwowarek
Date: Sun, 15 Sep 2024 18:02:09 +0200
Subject: [PATCH] Validate Executors when used with unbounded Dispatcher and batching collectors (#932)

---
 .../collectors/AsyncParallelCollector.java   |  7 +++
 .../com/pivovarit/collectors/Dispatcher.java | 17 +-----
 .../collectors/ParallelStreamCollector.java  | 22 +++----
 .../pivovarit/collectors/Preconditions.java  | 24 ++++++++
 .../test/ExecutorValidationTest.java         | 59 +++++++++++++------
 5 files changed, 86 insertions(+), 43 deletions(-)
 create mode 100644 src/main/java/com/pivovarit/collectors/Preconditions.java

diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
index ab7f4a04..075a74e8 100644
--- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
+++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
@@ -17,6 +17,7 @@
 import static com.pivovarit.collectors.BatchingSpliterator.batching;
 import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
+import static com.pivovarit.collectors.Preconditions.requireValidExecutor;
 import static java.util.Objects.requireNonNull;
 import static java.util.concurrent.CompletableFuture.allOf;
 import static java.util.concurrent.CompletableFuture.supplyAsync;
@@ -109,6 +110,7 @@ private static CompletableFuture> combine(List
     Collector>> collectingToStream(Function mapper, Executor executor) {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
+        requireValidExecutor(executor);

         return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), Function.identity());
     }
@@ -117,6 +119,7 @@ private static CompletableFuture> combine(List
               i)
@@ -144,6 +147,7 @@ private static CompletableFuture> combine(List
             (mapper, Dispatcher.from(executor), s -> s.collect(collector));
     }
@@ -153,6 +157,7 @@ private static CompletableFuture> combine(List
               s.collect(collector))
@@ -185,6 +190,7 @@ private BatchingCollectors() {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
         requireValidParallelism(parallelism);
+        requireValidExecutor(executor);

         return parallelism == 1
               ? asyncCollector(mapper, executor, s -> s.collect(collector))
@@ -197,6 +203,7 @@ private BatchingCollectors() {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
         requireValidParallelism(parallelism);
+        requireValidExecutor(executor);

         return parallelism == 1
               ? asyncCollector(mapper, executor, i -> i)
diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java
index de4db63a..3bb2b1a8 100644
--- a/src/main/java/com/pivovarit/collectors/Dispatcher.java
+++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java
@@ -14,6 +14,8 @@
 import java.util.function.Function;
 import java.util.function.Supplier;

+import static com.pivovarit.collectors.Preconditions.requireValidExecutor;
+
 /**
  * @author Grzegorz Piwowarek
  */
@@ -48,6 +50,7 @@ private Dispatcher(int permits) {
     }

     private Dispatcher(Executor executor) {
+        requireValidExecutor(executor);
         this.executor = executor;
         this.limiter = null;
     }
@@ -171,20 +174,6 @@ private static ExecutorService defaultExecutorService() {
         return Executors.newVirtualThreadPerTaskExecutor();
     }

-    private static void requireValidExecutor(Executor executor) {
-        if (executor instanceof ThreadPoolExecutor tpe) {
-            switch (tpe.getRejectedExecutionHandler()) {
-                case ThreadPoolExecutor.DiscardPolicy __ ->
-                  throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
-                case ThreadPoolExecutor.DiscardOldestPolicy __ ->
-                  throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
-                default -> {
-                    // no-op
-                }
-            }
-        }
-    }
-
     private static void retry(Runnable runnable) {
         try {
             runnable.run();
diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java
index 64fe8a84..7772a1c1 100644
--- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java
+++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java
@@ -19,6 +19,7 @@
 import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
 import static com.pivovarit.collectors.CompletionStrategy.ordered;
 import static com.pivovarit.collectors.CompletionStrategy.unordered;
+import static com.pivovarit.collectors.Preconditions.requireValidExecutor;
 import static java.util.Collections.emptySet;
 import static java.util.Objects.requireNonNull;
 import static java.util.stream.Collectors.collectingAndThen;
@@ -100,6 +101,7 @@ public Set characteristics() {
     static Collector> streaming(Function mapper, Executor executor) {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
+        requireValidExecutor(executor);

         return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor));
     }
@@ -108,6 +110,7 @@ public Set characteristics() {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
         requireValidParallelism(parallelism);
+        requireValidExecutor(executor);

         return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
     }
@@ -128,15 +131,16 @@ public Set characteristics() {
     static Collector> streamingOrdered(Function mapper, Executor executor) {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
+        requireValidExecutor(executor);

         return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor));
     }

-    static Collector> streamingOrdered(Function mapper, Executor executor,
-                                       int parallelism) {
+    static Collector> streamingOrdered(Function mapper, Executor executor, int parallelism) {
         requireNonNull(executor, "executor can't be null");
         requireNonNull(mapper, "mapper can't be null");
         requireValidParallelism(parallelism);
+        requireValidExecutor(executor);

         return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism));
     }

@@ -146,30 +150,29 @@ static final class BatchingCollectors {
         private BatchingCollectors() {
         }

-        static Collector> streaming(Function mapper, Executor executor,
-                                    int parallelism) {
+        static Collector> streaming(Function mapper, Executor executor, int parallelism) {
             requireNonNull(executor, "executor can't be null");
             requireNonNull(mapper, "mapper can't be null");
             requireValidParallelism(parallelism);
+            requireValidExecutor(executor);

             return parallelism == 1
               ? syncCollector(mapper)
               : batchingCollector(mapper, executor, parallelism);
         }

-        static Collector> streamingOrdered(Function mapper, Executor executor,
-                                           int parallelism) {
+        static Collector> streamingOrdered(Function mapper, Executor executor, int parallelism) {
             requireNonNull(executor, "executor can't be null");
             requireNonNull(mapper, "mapper can't be null");
             requireValidParallelism(parallelism);
+            requireValidExecutor(executor);

             return parallelism == 1
               ? syncCollector(mapper)
               : batchingCollector(mapper, executor, parallelism);
         }

-        private static Collector> batchingCollector(Function mapper,
-                                                    Executor executor, int parallelism) {
+        private static Collector> batchingCollector(Function mapper, Executor executor, int parallelism) {
             return collectingAndThen(
               toList(),
               list -> {
@@ -195,8 +198,7 @@ private BatchingCollectors() {

         private static Collector, Stream> syncCollector(Function mapper) {
             return Collector.of(Stream::builder, (rs, t) -> rs.add(mapper.apply(t)), (rs, rs2) -> {
-                throw new UnsupportedOperationException(
-                  "Using parallel stream with parallel collectors is a bad idea");
+                throw new UnsupportedOperationException("Using parallel stream with parallel collectors is a bad idea");
             }, Stream.Builder::build);
         }
     }
diff --git a/src/main/java/com/pivovarit/collectors/Preconditions.java b/src/main/java/com/pivovarit/collectors/Preconditions.java
new file mode 100644
index 00000000..65d9ae61
--- /dev/null
+++ b/src/main/java/com/pivovarit/collectors/Preconditions.java
@@ -0,0 +1,24 @@
+package com.pivovarit.collectors;
+
+import java.util.concurrent.Executor;
+import java.util.concurrent.ThreadPoolExecutor;
+
+final class Preconditions {
+
+    private Preconditions() {
+    }
+
+    static void requireValidExecutor(Executor executor) {
+        if (executor instanceof ThreadPoolExecutor tpe) {
+            switch (tpe.getRejectedExecutionHandler()) {
+                case ThreadPoolExecutor.DiscardPolicy __ ->
+                  throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
+                case ThreadPoolExecutor.DiscardOldestPolicy __ ->
+                  throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
+                default -> {
+                    // no-op
+                }
+            }
+        }
+    }
+}
diff --git a/src/test/java/com/pivovarit/collectors/test/ExecutorValidationTest.java b/src/test/java/com/pivovarit/collectors/test/ExecutorValidationTest.java
index 2ff8daa3..cf9ce031 100644
--- a/src/test/java/com/pivovarit/collectors/test/ExecutorValidationTest.java
+++ b/src/test/java/com/pivovarit/collectors/test/ExecutorValidationTest.java
@@ -4,7 +4,8 @@
 import org.junit.jupiter.api.DynamicTest;
 import org.junit.jupiter.api.TestFactory;

-import java.util.concurrent.ExecutorService;
+import java.util.List;
+import java.util.concurrent.Executor;
 import java.util.concurrent.SynchronousQueue;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
@@ -12,30 +13,50 @@
 import java.util.stream.Collector;
 import java.util.stream.Stream;

-import static java.util.stream.Stream.of;
+import static com.pivovarit.collectors.test.ExecutorValidationTest.CollectorDefinition.collector;
+import static java.util.stream.Collectors.collectingAndThen;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;

 class ExecutorValidationTest {

+    private static Stream> allWithCustomExecutors() {
+        return Stream.of(
+          collector("parallel(e)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e), c -> c.thenApply(Stream::toList).join())),
+          collector("parallel(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e, 1), c -> c.thenApply(Stream::toList).join())),
+          collector("parallel(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e, 4), c -> c.thenApply(Stream::toList).join())),
+          collector("parallel(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e, 1), c -> c.thenApply(Stream::toList).join())),
+          collector("parallel(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e, 4), c -> c.thenApply(Stream::toList).join())),
+          collector("parallelToStream(e)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e), Stream::toList)),
+          collector("parallelToStream(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e, 1), Stream::toList)),
+          collector("parallelToStream(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e, 4), Stream::toList)),
+          collector("parallelToStream(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e, 1), Stream::toList)),
+          collector("parallelToStream(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e, 4), Stream::toList)),
+          collector("parallelToOrderedStream(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e, 1), Stream::toList)),
+          collector("parallelToOrderedStream(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e, 4), Stream::toList)),
+          collector("parallelToOrderedStream(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e, 1), Stream::toList)),
+          collector("parallelToOrderedStream(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e, 4), Stream::toList))
+        );
+    }
+
     @TestFactory
-    Stream shouldStartProcessingElementsTests() {
-        return of(
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallel(i -> i, e, 2), "parallel"),
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToStream(i -> i, e, 2), "parallelToStream"),
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream"),
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallel(i -> i, e, 2), "parallel (batching)"),
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToStream(i -> i, e, 2), "parallelToStream (batching)"),
-          shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream (batching)")
-        ).flatMap(i -> i);
+    Stream shouldRejectInvalidRejectedExecutionHandlerFactory() {
+        return allWithCustomExecutors()
+          .flatMap(c -> Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy())
+            .map(dp -> DynamicTest.dynamicTest("%s : %s".formatted(c.name(), dp.getClass().getSimpleName()), () -> {
+                try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) {
+                    assertThatThrownBy(() -> Stream.of(1, 2, 3).collect(c.factory().collector(i -> i, e))).isExactlyInstanceOf(IllegalArgumentException.class);
+                }
+            })));
+    }
+
+    protected record CollectorDefinition(String name, CollectorFactory factory) {
+        static CollectorDefinition collector(String name, CollectorFactory factory) {
+            return new CollectorDefinition<>(name, factory);
+        }
     }

-    private static Stream shouldRejectInvalidRejectedExecutionHandler(Function> collector, String name) {
-        return Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy())
-          .map(dp -> DynamicTest.dynamicTest(name + " : " + dp.getClass().getSimpleName(), () -> {
-              try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) {
-                  assertThatThrownBy(() -> Stream.of(1, 2, 3)
-                    .collect(collector.apply(e))).isExactlyInstanceOf(IllegalArgumentException.class);
-              }
-          }));
+    @FunctionalInterface
+    private interface CollectorFactory {
+        Collector> collector(Function f, Executor executor);
     }
 }
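For readers skimming the patch, a minimal, hypothetical caller (not part of the change) illustrates what the new Preconditions.requireValidExecutor check buys: a ThreadPoolExecutor configured with DiscardPolicy or DiscardOldestPolicy is now rejected with an IllegalArgumentException the moment the collector is created, instead of being accepted and silently dropping tasks under load. The class name, queue capacity, and mapper below are illustrative; ParallelCollectors.parallel and the exception message come from the library itself.

import com.pivovarit.collectors.ParallelCollectors;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

class DiscardingExecutorExample {

    public static void main(String[] args) {
        // An executor whose saturation policy silently drops tasks
        var discarding = new ThreadPoolExecutor(2, 2, 0, TimeUnit.MILLISECONDS,
          new LinkedBlockingQueue<>(1), new ThreadPoolExecutor.DiscardPolicy());

        try {
            // The factory method now fails fast, before any task is submitted
            Stream.of(1, 2, 3).collect(ParallelCollectors.parallel(i -> i * 2, discarding, 2));
        } catch (IllegalArgumentException expected) {
            // "Executor's RejectedExecutionHandler can't discard tasks"
            System.out.println(expected.getMessage());
        } finally {
            discarding.shutdown();
        }
    }
}

Hoisting the check out of Dispatcher into the shared Preconditions class is what lets every entry point apply it: the unbounded Dispatcher(Executor) constructor as well as the Batching factory methods, which only build their Dispatcher inside a collectingAndThen finisher, after the upstream elements have already been collected.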