Validate Executors when used with unbounded Dispatcher and batching collectors (#932)
pivovarit authored Sep 15, 2024
1 parent 610aba8 commit f2317ac
Showing 5 changed files with 86 additions and 43 deletions.
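
For context, a minimal sketch (not part of the diff) of the behavior this commit introduces: collectors built on the unbounded Dispatcher, as well as the batching variants, now reject executors whose RejectedExecutionHandler silently drops tasks. The class name ExecutorValidationDemo is illustrative; the collector factory and the executor setup mirror the test added below.

    import com.pivovarit.collectors.ParallelCollectors;

    import java.util.concurrent.SynchronousQueue;
    import java.util.concurrent.ThreadPoolExecutor;
    import java.util.concurrent.TimeUnit;
    import java.util.stream.Stream;

    class ExecutorValidationDemo {
        public static void main(String[] args) {
            // A ThreadPoolExecutor whose RejectedExecutionHandler silently drops rejected tasks.
            var discarding = new ThreadPoolExecutor(
                    2, 2, 0, TimeUnit.MILLISECONDS,
                    new SynchronousQueue<>(),
                    new ThreadPoolExecutor.DiscardPolicy());

            try (discarding) {
                // After this change, collector creation fails fast with an IllegalArgumentException
                // instead of accepting an executor that could drop tasks without a trace.
                Stream.of(1, 2, 3)
                      .collect(ParallelCollectors.parallel(i -> i, discarding));
            }
        }
    }
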
src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java
@@ -17,6 +17,7 @@

import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.Preconditions.requireValidExecutor;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
@@ -109,6 +110,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper, Executor executor) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidExecutor(executor);

return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), Function.identity());
}
@@ -117,6 +119,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
@@ -144,6 +147,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(collector, "collector can't be null");
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidExecutor(executor);

return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), s -> s.collect(collector));
}
@@ -153,6 +157,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
@@ -185,6 +190,7 @@ private BatchingCollectors() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
@@ -197,6 +203,7 @@ private BatchingCollectors() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
17 changes: 3 additions & 14 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
@@ -14,6 +14,8 @@
import java.util.function.Function;
import java.util.function.Supplier;

import static com.pivovarit.collectors.Preconditions.requireValidExecutor;

/**
* @author Grzegorz Piwowarek
*/
@@ -48,6 +50,7 @@ private Dispatcher(int permits) {
}

private Dispatcher(Executor executor) {
requireValidExecutor(executor);
this.executor = executor;
this.limiter = null;
}
@@ -171,20 +174,6 @@ private static ExecutorService defaultExecutorService() {
return Executors.newVirtualThreadPerTaskExecutor();
}

private static void requireValidExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
switch (tpe.getRejectedExecutionHandler()) {
case ThreadPoolExecutor.DiscardPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
case ThreadPoolExecutor.DiscardOldestPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
default -> {
// no-op
}
}
}
}

private static void retry(Runnable runnable) {
try {
runnable.run();
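
Aside (not part of the diff): discarding handlers are problematic here because a task dropped by DiscardPolicy or DiscardOldestPolicy never runs, so a CompletableFuture scheduled on such an executor may never complete and downstream join() calls can block indefinitely. A plain-JDK sketch of that failure mode (class and helper names are illustrative):

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.SynchronousQueue;
    import java.util.concurrent.ThreadPoolExecutor;
    import java.util.concurrent.TimeUnit;

    class DiscardPolicyHazard {
        public static void main(String[] args) throws InterruptedException {
            // One worker, no queue capacity, and a handler that drops rejected tasks.
            var executor = new ThreadPoolExecutor(
                    1, 1, 0, TimeUnit.MILLISECONDS,
                    new SynchronousQueue<>(),
                    new ThreadPoolExecutor.DiscardPolicy());

            var first = CompletableFuture.runAsync(() -> sleepQuietly(500), executor);
            var second = CompletableFuture.runAsync(() -> sleepQuietly(500), executor); // rejected and silently discarded

            first.join();                         // completes normally
            TimeUnit.SECONDS.sleep(1);
            System.out.println(second.isDone());  // false -- and it never completes
            executor.shutdownNow();
        }

        private static void sleepQuietly(long millis) {
            try {
                Thread.sleep(millis);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
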
22 changes: 12 additions & 10 deletions src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java
@@ -19,6 +19,7 @@
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.CompletionStrategy.ordered;
import static com.pivovarit.collectors.CompletionStrategy.unordered;
import static com.pivovarit.collectors.Preconditions.requireValidExecutor;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
@@ -100,6 +101,7 @@ public Set<Characteristics> characteristics() {
static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidExecutor(executor);

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor));
}
@@ -108,6 +110,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
}
@@ -128,15 +131,16 @@ public Set<Characteristics> characteristics() {
static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidExecutor(executor);

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor,
int parallelism) {
static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism));
}
@@ -146,30 +150,29 @@ static final class BatchingCollectors {
private BatchingCollectors() {
}

static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor,
int parallelism) {
static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? syncCollector(mapper)
: batchingCollector(mapper, executor, parallelism);
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor,
int parallelism) {
static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);
requireValidExecutor(executor);

return parallelism == 1
? syncCollector(mapper)
: batchingCollector(mapper, executor, parallelism);
}

private static <T, R> Collector<T, ?, Stream<R>> batchingCollector(Function<T, R> mapper,
Executor executor, int parallelism) {
private static <T, R> Collector<T, ?, Stream<R>> batchingCollector(Function<T, R> mapper, Executor executor, int parallelism) {
return collectingAndThen(
toList(),
list -> {
@@ -195,8 +198,7 @@ private BatchingCollectors() {

private static <T, R> Collector<T, Stream.Builder<R>, Stream<R>> syncCollector(Function<T, R> mapper) {
return Collector.of(Stream::builder, (rs, t) -> rs.add(mapper.apply(t)), (rs, rs2) -> {
throw new UnsupportedOperationException(
"Using parallel stream with parallel collectors is a bad idea");
throw new UnsupportedOperationException("Using parallel stream with parallel collectors is a bad idea");
}, Stream.Builder::build);
}
}
24 changes: 24 additions & 0 deletions src/main/java/com/pivovarit/collectors/Preconditions.java
@@ -0,0 +1,24 @@
package com.pivovarit.collectors;

import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;

final class Preconditions {

private Preconditions() {
}

static void requireValidExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
switch (tpe.getRejectedExecutionHandler()) {
case ThreadPoolExecutor.DiscardPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
case ThreadPoolExecutor.DiscardOldestPolicy __ ->
throw new IllegalArgumentException("Executor's RejectedExecutionHandler can't discard tasks");
default -> {
// no-op
}
}
}
}
}
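
For reference, a short sketch of how this check treats the standard JDK handlers (illustrative only; requireValidExecutor is package-private, so the snippet assumes it lives in com.pivovarit.collectors, and the PreconditionsSketch/withHandler names are made up). Note that only ThreadPoolExecutor instances are inspected; any other Executor, such as a virtual-thread-per-task executor, passes through unchecked.

    package com.pivovarit.collectors;

    import java.util.concurrent.Executors;
    import java.util.concurrent.RejectedExecutionHandler;
    import java.util.concurrent.SynchronousQueue;
    import java.util.concurrent.ThreadPoolExecutor;
    import java.util.concurrent.TimeUnit;

    class PreconditionsSketch {
        public static void main(String[] args) {
            // Passes: AbortPolicy surfaces rejections as RejectedExecutionException.
            Preconditions.requireValidExecutor(withHandler(new ThreadPoolExecutor.AbortPolicy()));
            // Passes: CallerRunsPolicy still runs every task (in the caller's thread).
            Preconditions.requireValidExecutor(withHandler(new ThreadPoolExecutor.CallerRunsPolicy()));
            // Passes unchecked: not a ThreadPoolExecutor at all.
            Preconditions.requireValidExecutor(Executors.newVirtualThreadPerTaskExecutor());
            // Throws IllegalArgumentException: tasks could be dropped silently.
            Preconditions.requireValidExecutor(withHandler(new ThreadPoolExecutor.DiscardOldestPolicy()));
        }

        private static ThreadPoolExecutor withHandler(RejectedExecutionHandler handler) {
            return new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), handler);
        }
    }
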
src/test/java/com/pivovarit/collectors/test/ExecutorValidationTest.java
@@ -4,38 +4,59 @@
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.concurrent.ExecutorService;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static java.util.stream.Stream.of;
import static com.pivovarit.collectors.test.ExecutorValidationTest.CollectorDefinition.collector;
import static java.util.stream.Collectors.collectingAndThen;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ExecutorValidationTest {

private static Stream<CollectorDefinition<Integer, Integer>> allWithCustomExecutors() {
return Stream.of(
collector("parallel(e)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e), c -> c.thenApply(Stream::toList).join())),
collector("parallel(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e, 1), c -> c.thenApply(Stream::toList).join())),
collector("parallel(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallel(f, e, 4), c -> c.thenApply(Stream::toList).join())),
collector("parallel(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e, 1), c -> c.thenApply(Stream::toList).join())),
collector("parallel(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e, 4), c -> c.thenApply(Stream::toList).join())),
collector("parallelToStream(e)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e), Stream::toList)),
collector("parallelToStream(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e, 1), Stream::toList)),
collector("parallelToStream(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e, 4), Stream::toList)),
collector("parallelToStream(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e, 1), Stream::toList)),
collector("parallelToStream(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e, 4), Stream::toList)),
collector("parallelToOrderedStream(e, p=1)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e, 1), Stream::toList)),
collector("parallelToOrderedStream(e, p=4)", (f, e) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e, 4), Stream::toList)),
collector("parallelToOrderedStream(e, p=1) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e, 1), Stream::toList)),
collector("parallelToOrderedStream(e, p=4) [batching]", (f, e) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e, 4), Stream::toList))
);
}

@TestFactory
Stream<DynamicTest> shouldStartProcessingElementsTests() {
return of(
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallel(i -> i, e, 2), "parallel"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToStream(i -> i, e, 2), "parallelToStream"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallel(i -> i, e, 2), "parallel (batching)"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToStream(i -> i, e, 2), "parallelToStream (batching)"),
shouldRejectInvalidRejectedExecutionHandler(e -> ParallelCollectors.Batching.parallelToOrderedStream(i -> i, e, 2), "parallelToOrderedStream (batching)")
).flatMap(i -> i);
Stream<DynamicTest> shouldRejectInvalidRejectedExecutionHandlerFactory() {
return allWithCustomExecutors()
.flatMap(c -> Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy())
.map(dp -> DynamicTest.dynamicTest("%s : %s".formatted(c.name(), dp.getClass().getSimpleName()), () -> {
try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) {
assertThatThrownBy(() -> Stream.of(1, 2, 3).collect(c.factory().collector(i -> i, e))).isExactlyInstanceOf(IllegalArgumentException.class);
}
})));
}

protected record CollectorDefinition<T, R>(String name, CollectorFactory<T, R> factory) {
static <T, R> CollectorDefinition<T, R> collector(String name, CollectorFactory<T, R> factory) {
return new CollectorDefinition<>(name, factory);
}
}

private static Stream<DynamicTest> shouldRejectInvalidRejectedExecutionHandler(Function<ExecutorService, Collector<Integer, ?, ?>> collector, String name) {
return Stream.of(new ThreadPoolExecutor.DiscardOldestPolicy(), new ThreadPoolExecutor.DiscardPolicy())
.map(dp -> DynamicTest.dynamicTest(name + " : " + dp.getClass().getSimpleName(), () -> {
try (var e = new ThreadPoolExecutor(2, 2000, 0, TimeUnit.MILLISECONDS, new SynchronousQueue<>(), dp)) {
assertThatThrownBy(() -> Stream.of(1, 2, 3)
.collect(collector.apply(e))).isExactlyInstanceOf(IllegalArgumentException.class);
}
}));
@FunctionalInterface
private interface CollectorFactory<T, R> {
Collector<T, ?, List<R>> collector(Function<T, R> f, Executor executor);
}
}
