diff --git a/src/test/java/com/pivovarit/collectors/functional/ExecutorPollutionTest.java b/src/test/java/com/pivovarit/collectors/functional/ExecutorPollutionTest.java index b3db029c..f845842b 100644 --- a/src/test/java/com/pivovarit/collectors/functional/ExecutorPollutionTest.java +++ b/src/test/java/com/pivovarit/collectors/functional/ExecutorPollutionTest.java @@ -1,9 +1,17 @@ package com.pivovarit.collectors.functional; import com.pivovarit.collectors.ParallelCollectors; -import org.junit.jupiter.api.DynamicTest; -import org.junit.jupiter.api.TestFactory; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; @@ -15,43 +23,62 @@ import static com.pivovarit.collectors.ParallelCollectors.Batching.parallel; import static java.util.stream.Collectors.toList; -import static java.util.stream.Stream.of; +@ExtendWith(ExecutorPollutionTest.ContextProvider.class) class ExecutorPollutionTest { - @TestFactory - Stream shouldStartProcessingElementsTests() { - return of( - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallel(f, e, p), "parallel#1"), - shouldNotSubmitMoreTasksThanParallelism((f, __, p) -> ParallelCollectors.parallel(f, p), "parallel#2"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallel(f, toList(), e, p), "parallel#3"), - shouldNotSubmitMoreTasksThanParallelism((f, __, p) -> ParallelCollectors.parallel(f, toList(), p), "parallel#4"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallelToStream(f, e, p), "parallelToStream#1"), - shouldNotSubmitMoreTasksThanParallelism((f, __, p) -> ParallelCollectors.parallelToStream(f, p), "parallelToStream#2"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.parallelToOrderedStream(f, e, p), "parallelToOrderedStream#1"), - shouldNotSubmitMoreTasksThanParallelism((f, __, p) -> ParallelCollectors.parallelToOrderedStream(f, p), "parallelToOrderedStream#2"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> parallel(f, e, p), "parallel#1 (batching)"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> parallel(f, toList(), e, p), "parallel#2 (batching)"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.Batching.parallelToStream(f, e, p), "parallelToStream (batching)"), - shouldNotSubmitMoreTasksThanParallelism((f, e, p) -> ParallelCollectors.Batching.parallelToOrderedStream(f, e, p), "parallelToOrderedStream (batching)") - ); + @TestTemplate + void shouldNotPolluteExecutorWhenNoParallelism(CollectorFactory collector) { + try (var e = warmedUp(new ThreadPoolExecutor(1 , 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2)))) { + + var result = Stream.generate(() -> 42) + .limit(1000) + .collect(collector.apply(i -> i, e, 1)); + + switch (result) { + case CompletableFuture cf -> cf.join(); + case Stream s -> s.forEach((__) -> {}); + default -> throw new IllegalStateException("can't happen"); + } + } } - private static DynamicTest shouldNotSubmitMoreTasksThanParallelism(CollectorFactory collector, String name) { - return DynamicTest.dynamicTest(name, () -> { - try (var e = warmedUp(new ThreadPoolExecutor(2, 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2)))) { + @TestTemplate + void shouldNotPolluteExecutorWhenLimitedParallelism(CollectorFactory collector) { + try (var e = warmedUp(new ThreadPoolExecutor(2 , 2, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(2)))) { - var result = Stream.generate(() -> 42) - .limit(1000) - .collect(collector.apply(i -> i, e, 2)); + var result = Stream.generate(() -> 42) + .limit(1000) + .collect(collector.apply(i -> i, e, 2)); - switch (result) { - case CompletableFuture cf -> cf.join(); - case Stream s -> s.forEach((__) -> {}); - default -> throw new IllegalStateException("can't happen"); - } + switch (result) { + case CompletableFuture cf -> cf.join(); + case Stream s -> s.forEach((__) -> {}); + default -> throw new IllegalStateException("can't happen"); } - }); + } + } + + static class ContextProvider implements TestTemplateInvocationContextProvider { + + @Override + public boolean supportsTestTemplate(ExtensionContext context) { + return true; + } + + @Override + public Stream provideTestTemplateInvocationContexts(ExtensionContext context) { + return Stream.of( + collector("parallel()", (f, e, p) -> ParallelCollectors.parallel(f, e, p)), + collector("parallel(toList())", (f, e, p) -> ParallelCollectors.parallel(f, toList(), e, p)), + collector("parallelToStream()", (f, e, p) -> ParallelCollectors.parallelToStream(f, e, p)), + collector("parallelToOrderedStream()", (f, e, p) -> ParallelCollectors.parallelToOrderedStream(f, e, p)), + collector("parallel() (batching)", (f, e, p) -> parallel(f, e, p)), + collector("parallel(toList()) (batching)", (f, e, p) -> parallel(f, toList(), e, p)), + collector("parallelToStream() (batching)", (f, e, p) -> ParallelCollectors.Batching.parallelToStream(f, e, p)), + collector("parallelToOrderedStream() (batching)", (f, e, p) -> ParallelCollectors.Batching.parallelToOrderedStream(f, e, p)) + ); + } } interface CollectorFactory { @@ -64,4 +91,28 @@ private static ThreadPoolExecutor warmedUp(ThreadPoolExecutor e) { } return e; } + + private static TestTemplateInvocationContext collector(String name, CollectorFactory factory) { + return new TestTemplateInvocationContext() { + @Override + public String getDisplayName(int invocationIndex) { + return name + " [" + invocationIndex + "]"; + } + + @Override + public List getAdditionalExtensions() { + return List.of(new ParameterResolver() { + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { + return parameterContext.getParameter().getType().equals(CollectorFactory.class); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { + return factory; + } + }); + } + }; + } }