Skip to content

Commit

Permalink
Rewrite invalid parallelism tests (#942)
Browse files Browse the repository at this point in the history
  • Loading branch information
pivovarit authored Sep 16, 2024
1 parent 2d5eaad commit e8aeec7
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 33 deletions.
46 changes: 15 additions & 31 deletions src/test/java/com/pivovarit/collectors/FunctionalTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ Stream<DynamicTest> collectors() {
virtualThreadsTests((m, e, p) -> adapt(parallel(m)), "ParallelCollectors.parallel() [virtual]"),
virtualThreadsTests((m, e, p) -> adapt(parallel(m, p)), "ParallelCollectors.parallel() [virtual]"),
// platform threads
tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> parallel(m, toList(), e), "ParallelCollectors.parallel(toList(), p=inf)", false),
tests((m, e, p) -> parallel(m, toSet(), e), "ParallelCollectors.parallel(toSet(), p=inf)", false),
tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM), true),
tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM), true),
tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM)),
tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM)),
tests((m, e, p) -> parallel(m, toList(), e), "ParallelCollectors.parallel(toList(), p=inf)"),
tests((m, e, p) -> parallel(m, toSet(), e), "ParallelCollectors.parallel(toSet(), p=inf)"),
tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM)),
tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM)),
// platform threads, with batching
batchTests((m, e, p) -> Batching.parallel(m, toList(), e, p), format("ParallelCollectors.Batching.parallel(toList(), p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> Batching.parallel(m, toSet(), e, p), format("ParallelCollectors.Batching.parallel(toSet(), p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> Batching.parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.Batching.parallel(toCollection(), p=%d)", PARALLELISM), true),
batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM), true)
batchTests((m, e, p) -> Batching.parallel(m, toList(), e, p), format("ParallelCollectors.Batching.parallel(toList(), p=%d)", PARALLELISM)),
batchTests((m, e, p) -> Batching.parallel(m, toSet(), e, p), format("ParallelCollectors.Batching.parallel(toSet(), p=%d)", PARALLELISM)),
batchTests((m, e, p) -> Batching.parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.Batching.parallel(toCollection(), p=%d)", PARALLELISM)),
batchTests((m, e, p) -> adapt(Batching.parallel(m, e, p)), format("ParallelCollectors.Batching.parallel(p=%d)", PARALLELISM))
).flatMap(i -> i);
}

Expand Down Expand Up @@ -167,16 +167,12 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThread
);
}

private static <R extends Collection<Integer>> Stream<DynamicTest> tests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean limitedParallelism) {
var tests = of(
private static <R extends Collection<Integer>> Stream<DynamicTest> tests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return of(
shouldStartConsumingImmediately(collector, name),
shouldShortCircuitOnException(collector, name),
shouldInterruptOnException(collector, name)
);

tests = limitedParallelism ? of(shouldRejectInvalidParallelism(collector, name)) : tests;

return tests;
}

private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThreadsStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
Expand All @@ -190,15 +186,12 @@ private static <R extends Collection<Integer>> Stream<DynamicTest> streamingTest
return of(
shouldStartConsumingImmediately(collector, name),
shouldPushElementsToStreamAsSoonAsPossible(collector, name),
shouldShortCircuitOnException(collector, name),
shouldRejectInvalidParallelism(collector, name)
shouldShortCircuitOnException(collector, name)
);
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name, boolean limitedParallelism) {
return Stream.concat(
tests(collector, name, limitedParallelism),
of(shouldProcessOnNThreadsETParallelism(collector, name)));
private static <R extends Collection<Integer>> Stream<DynamicTest> batchTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return Stream.concat(tests(collector, name), of(shouldProcessOnNThreadsETParallelism(collector, name)));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
Expand Down Expand Up @@ -263,15 +256,6 @@ private static <R extends Collection<Integer>> DynamicTest shouldShortCircuitOnE
});
}

private static <R extends Collection<Integer>> DynamicTest shouldRejectInvalidParallelism(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return dynamicTest(format("%s: should reject invalid parallelism", name), () -> {
withExecutor(e -> {
assertThatThrownBy(() -> collector.apply(i -> i, e, -1))
.isExactlyInstanceOf(IllegalArgumentException.class);
});
});
}

private static <R extends Collection<Integer>> DynamicTest shouldStartConsumingImmediately(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return dynamicTest(format("%s: should start consuming immediately", name), () -> {
try (var e = Executors.newCachedThreadPool()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.pivovarit.collectors.ParallelCollectors;
import com.pivovarit.collectors.TestUtils;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

Expand All @@ -20,6 +19,7 @@
import static com.pivovarit.collectors.test.BasicParallelismTest.CollectorDefinition.collector;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;

class BasicParallelismTest {
Expand Down Expand Up @@ -61,11 +61,22 @@ Stream<DynamicTest> shouldProcessAllElementsWithMaxParallelism() {
Stream<DynamicTest> shouldRespectMaxParallelism() {
return allBounded()
.map(c -> DynamicTest.dynamicTest(c.name(), () -> {
var duration = timed(() -> IntStream.range(0, 10).boxed().collect(c.factory().collector(i -> TestUtils.returnWithDelay(i, Duration.ofMillis(100)), 2)));
var duration = timed(() -> IntStream.range(0, 10).boxed()
.collect(c.factory().collector(i -> TestUtils.returnWithDelay(i, Duration.ofMillis(100)), 2)));
assertThat(duration).isCloseTo(Duration.ofMillis(500), Duration.ofMillis(100));
}));
}

@TestFactory
Stream<DynamicTest> shouldRejectInvalidParallelism() {
return allBounded()
.flatMap(c -> Stream.of(-1, 0)
.map(p -> DynamicTest.dynamicTest("%s [p=%d]".formatted(c.name(), p), () -> {
assertThatThrownBy(() -> Stream.of(1).collect(c.factory().collector(i -> i, p)))
.isExactlyInstanceOf(IllegalArgumentException.class);
})));
}

protected record CollectorDefinition<T, R>(String name, CollectorFactory<T, R> factory) {
static <T, R> CollectorDefinition<T, R> collector(String name, CollectorFactory<T, R> collector) {
return new CollectorDefinition<>(name, collector);
Expand Down

0 comments on commit e8aeec7

Please sign in to comment.