Skip to content

Commit

Permalink
Rewrite shortcircuiting tests (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
pivovarit authored Sep 16, 2024
1 parent 39af22d commit a520c44
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 61 deletions.
61 changes: 2 additions & 59 deletions src/test/java/com/pivovarit/collectors/FunctionalTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
Expand All @@ -22,15 +21,12 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.pivovarit.collectors.ParallelCollectors.parallel;
import static com.pivovarit.collectors.ParallelCollectors.parallelToOrderedStream;
import static com.pivovarit.collectors.ParallelCollectors.parallelToStream;
import static com.pivovarit.collectors.TestUtils.incrementAndThrow;
import static com.pivovarit.collectors.TestUtils.returnWithDelay;
import static com.pivovarit.collectors.TestUtils.runWithExecutor;
import static com.pivovarit.collectors.TestUtils.withExecutor;
import static java.lang.String.format;
import static java.time.Duration.ofMillis;
Expand All @@ -42,7 +38,6 @@
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Stream.of;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.DynamicTest.dynamicTest;

Expand All @@ -56,22 +51,6 @@ class FunctionalTest {
@TestFactory
Stream<DynamicTest> collectors() {
return of(
// virtual threads
virtualThreadsTests((m, e, p) -> parallel(m, toList()), "ParallelCollectors.parallel(toList()) [virtual]"),
virtualThreadsTests((m, e, p) -> parallel(m, toList(), p), "ParallelCollectors.parallel(toList()) [virtual]"),
virtualThreadsTests((m, e, p) -> parallel(m, toSet()), "ParallelCollectors.parallel(toSet()) [virtual]"),
virtualThreadsTests((m, e, p) -> parallel(m, toSet(), p), "ParallelCollectors.parallel(toSet()) [virtual]"),
virtualThreadsTests((m, e, p) -> parallel(m, toCollection(LinkedList::new)), "ParallelCollectors.parallel(toCollection()) [virtual]"),
virtualThreadsTests((m, e, p) -> parallel(m, toCollection(LinkedList::new), p), "ParallelCollectors.parallel(toCollection()) [virtual]"),
virtualThreadsTests((m, e, p) -> adapt(parallel(m)), "ParallelCollectors.parallel() [virtual]"),
virtualThreadsTests((m, e, p) -> adapt(parallel(m, p)), "ParallelCollectors.parallel() [virtual]"),
// platform threads
tests((m, e, p) -> parallel(m, toList(), e, p), format("ParallelCollectors.parallel(toList(), p=%d)", PARALLELISM)),
tests((m, e, p) -> parallel(m, toSet(), e, p), format("ParallelCollectors.parallel(toSet(), p=%d)", PARALLELISM)),
tests((m, e, p) -> parallel(m, toList(), e), "ParallelCollectors.parallel(toList(), p=inf)"),
tests((m, e, p) -> parallel(m, toSet(), e), "ParallelCollectors.parallel(toSet(), p=inf)"),
tests((m, e, p) -> parallel(m, toCollection(LinkedList::new), e, p), format("ParallelCollectors.parallel(toCollection(), p=%d)", PARALLELISM)),
tests((m, e, p) -> adapt(parallel(m, e, p)), format("ParallelCollectors.parallel(p=%d)", PARALLELISM)),
// platform threads, with batching
batchTests((m, e, p) -> Batching.parallel(m, toList(), e, p), format("ParallelCollectors.Batching.parallel(toList(), p=%d)", PARALLELISM)),
batchTests((m, e, p) -> Batching.parallel(m, toSet(), e, p), format("ParallelCollectors.Batching.parallel(toSet(), p=%d)", PARALLELISM)),
Expand All @@ -83,9 +62,6 @@ Stream<DynamicTest> collectors() {
@TestFactory
Stream<DynamicTest> streaming_collectors() {
return of(
// virtual threads
virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToStream(m)), "ParallelCollectors.parallelToStream() [virtual]"),
virtualThreadsStreamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m)), "ParallelCollectors.parallelToOrderedStream() [virtual]"),
// platform threads
streamingTests((m, e, p) -> adaptAsync(parallelToStream(m, e, p)), format("ParallelCollectors.parallelToStream(p=%d)", PARALLELISM)),
streamingTests((m, e, p) -> adaptAsync(parallelToOrderedStream(m, e, p)), format("ParallelCollectors.parallelToOrderedStream(p=%d)", PARALLELISM))
Expand Down Expand Up @@ -157,27 +133,12 @@ void shouldExecuteEagerlyOnProvidedThreadPool() {
}
}

private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThreadsTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return of(shouldShortCircuitOnException(collector, name));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> tests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return of(shouldShortCircuitOnException(collector, name));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> virtualThreadsStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return of(shouldShortCircuitOnException(collector, name));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> streamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return of(
shouldPushElementsToStreamAsSoonAsPossible(collector, name),
shouldShortCircuitOnException(collector, name)
);
return of(shouldPushElementsToStreamAsSoonAsPossible(collector, name));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return Stream.concat(tests(collector, name), of(shouldProcessOnNThreadsETParallelism(collector, name)));
return of(shouldProcessOnNThreadsETParallelism(collector, name));
}

private static <R extends Collection<Integer>> Stream<DynamicTest> batchStreamingTests(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
Expand Down Expand Up @@ -224,24 +185,6 @@ private static <R extends Collection<Integer>> DynamicTest shouldProcessOnNThrea
});
}

private static <R extends Collection<Integer>> DynamicTest shouldShortCircuitOnException(CollectorSupplier<Function<Integer, Integer>, Executor, Integer, Collector<Integer, ?, CompletableFuture<R>>> collector, String name) {
return dynamicTest(format("%s: should short circuit on exception", name), () -> {
List<Integer> elements = IntStream.range(0, 100).boxed().toList();
int size = 4;

runWithExecutor(e -> {
AtomicInteger counter = new AtomicInteger();

assertThatThrownBy(elements.stream()
.collect(collector.apply(i -> incrementAndThrow(counter), e, PARALLELISM))::join)
.isInstanceOf(CompletionException.class)
.hasCauseExactlyInstanceOf(IllegalArgumentException.class);

assertThat(counter.longValue()).isLessThan(elements.size());
}, size);
});
}

private static Collector<Integer, ?, CompletableFuture<Collection<Integer>>> adapt(Collector<Integer, ?, CompletableFuture<Stream<Integer>>> input) {
return collectingAndThen(input, stream -> stream.thenApply(Stream::toList));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.pivovarit.collectors.test.ExceptionPropagationTest.CollectorDefinition.collector;
import static com.pivovarit.collectors.TestUtils.incrementAndThrow;
import static com.pivovarit.collectors.test.ExceptionHandlingTest.CollectorDefinition.collector;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ExceptionPropagationTest {
class ExceptionHandlingTest {

private static Stream<CollectorDefinition<Integer, Integer>> all() {
return Stream.of(
Expand Down Expand Up @@ -59,6 +62,22 @@ Stream<DynamicTest> shouldPropagateExceptionFactory() {
}));
}

@TestFactory
Stream<DynamicTest> shouldShortcircuitOnException() {
return all()
.map(c -> DynamicTest.dynamicTest(c.name(), () -> {
List<Integer> elements = IntStream.range(0, 100).boxed().toList();
AtomicInteger counter = new AtomicInteger();

assertThatThrownBy(() -> elements.stream()
.collect(c.collector().apply(i -> incrementAndThrow(counter))))
.isInstanceOf(CompletionException.class)
.hasCauseExactlyInstanceOf(IllegalArgumentException.class);

assertThat(counter.longValue()).isLessThan(elements.size());
}));
}

record CollectorDefinition<T, R>(String name, Function<Function<T, R>, Collector<T, ?, List<R>>> collector) {
static <T, R> CollectorDefinition<T, R> collector(String name, Function<Function<T, R>, Collector<T, ?, List<R>>> collector) {
return new CollectorDefinition<>(name, collector);
Expand Down

0 comments on commit a520c44

Please sign in to comment.