Skip to content

Commit

Permalink
Reject invalid Executor's RejectedExecutionHandlers (#873)
Browse files Browse the repository at this point in the history
Disallow `ThreadPoolExecutor.DiscardOldestPolicy` and `ThreadPoolExecutor.DiscardPolicy` since they are fundamentally unsafe and can induce permanent inconsistencies when dropping some of the tasks
  • Loading branch information
pivovarit authored May 1, 2024
1 parent 5da805b commit 03d52fc
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -36,12 +37,8 @@ private Dispatcher() {
this.limiter = null;
}

private Dispatcher(int permits) {
this.executor = defaultExecutorService();
this.limiter = new Semaphore(permits);
}

private Dispatcher(Executor executor, int permits) {
requireValidExecutor(executor);
this.executor = executor;
this.limiter = new Semaphore(permits);
}
Expand Down Expand Up @@ -157,4 +154,18 @@ public boolean cancel(boolean mayInterruptIfRunning) {
private static ExecutorService defaultExecutorService() {
return Executors.newVirtualThreadPerTaskExecutor();
}

private static void requireValidExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
switch (tpe.getRejectedExecutionHandler()) {
case ThreadPoolExecutor.DiscardPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
case ThreadPoolExecutor.DiscardOldestPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
default -> {
// no-op
}
}
}
}
}
40 changes: 40 additions & 0 deletions src/test/java/com/pivovarit/collectors/ExecutorValidationTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.pivovarit.collectors;

import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static java.util.stream.Stream.of;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ExecutorValidationTest {

@TestFactory
Stream<DynamicTest> shouldStartProcessingElementsTests() {
return of(
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallel(i -> i, e, 2), "parallel"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToStream(i -> i, e, 2), "parallelToStream"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallel(i -> i, e, 2), "parallel (batching)"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToStream(i -> i, e, 2), "parallelToStream (batching)"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream (batching)")
).flatMap(i -> i);
}

private static Stream<DynamicTest> shouldRejectInvalidRejectedExecutionHandler(Function<ExecutorService, Collector<Integer, ?, ?>> collector, String name) {
return Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy())
.map(dp -> DynamicTest.dynamicTest(name + " : " + dp.getClass().getSimpleName(), () -> {
try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) {
assertThatThrownBy(() -> Stream.of(1, 2, 3)
.collect(collector.apply(e))).isExactlyInstanceOf(IllegalArgumentException.class);
}
}));
}
}

0 comments on commit 03d52fc

Please sign in to comment.