diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index 843f56d8..348b62e2 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -8,6 +8,7 @@ import java.util.concurrent.FutureTask; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.function.Supplier; @@ -36,12 +37,8 @@ private Dispatcher() { this.limiter = null; } - private Dispatcher(int permits) { - this.executor = defaultExecutorService(); - this.limiter = new Semaphore(permits); - } - private Dispatcher(Executor executor, int permits) { + requireValidExecutor(executor); this.executor = executor; this.limiter = new Semaphore(permits); } @@ -157,4 +154,18 @@ public boolean cancel(boolean mayInterruptIfRunning) { private static ExecutorService defaultExecutorService() { return Executors.newVirtualThreadPerTaskExecutor(); } + + private static void requireValidExecutor(Executor executor) { + if (executor instanceof ThreadPoolExecutor tpe) { + switch (tpe.getRejectedExecutionHandler()) { + case ThreadPoolExecutor.DiscardPolicy __ -> + throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks"); + case ThreadPoolExecutor.DiscardOldestPolicy __ -> + throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks"); + default -> { + // no-op + } + } + } + } } diff --git a/src/test/java/com/pivovarit/collectors/ExecutorValidationTest.java b/src/test/java/com/pivovarit/collectors/ExecutorValidationTest.java new file mode 100644 index 00000000..88259bd7 --- /dev/null +++ b/src/test/java/com/pivovarit/collectors/ExecutorValidationTest.java @@ -0,0 +1,40 @@ +package com.pivovarit.collectors; + +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.TestFactory; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collector; +import java.util.stream.Stream; + +import static java.util.stream.Stream.of; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class ExecutorValidationTest { + + @TestFactory + Stream shouldStartProcessingElementsTests() { + return of( + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallel(i -> i, e, 2), "parallel"), + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToStream(i -> i, e, 2), "parallelToStream"), + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream"), + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallel(i -> i, e, 2), "parallel (batching)"), + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToStream(i -> i, e, 2), "parallelToStream (batching)"), + shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream (batching)") + ).flatMap(i -> i); + } + + private static Stream shouldRejectInvalidRejectedExecutionHandler(Function> collector, String name) { + return Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy()) + .map(dp -> DynamicTest.dynamicTest(name + " : " + dp.getClass().getSimpleName(), () -> { + try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) { + assertThatThrownBy(() -> Stream.of(1, 2, 3) + .collect(collector.apply(e))).isExactlyInstanceOf(IllegalArgumentException.class); + } + })); + } +}