Skip to content

Commit

Permalink
Unify Collector definitions (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
pivovarit authored Sep 16, 2024
1 parent 9fbf2cd commit 0ac103c
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 354 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,28 @@
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.pivovarit.collectors.test.BasicParallelismTest.CollectorDefinition.collector;
import static com.pivovarit.collectors.test.Factory.GenericCollector.limitedCollector;
import static com.pivovarit.collectors.test.Factory.e;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;

class BasicParallelismTest {

private static Stream<CollectorDefinition<Integer, Integer>> allBounded() {
private static Stream<Factory.GenericCollector<Factory.CollectorFactoryWithParallelism<Integer, Integer>>> allBounded() {
return Stream.of(
collector("parallel(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallel(f, e(), p), c -> c.join().toList())),
collector("parallel(toList(), e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e(), p), CompletableFuture::join)),
collector("parallel(toList(), e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p), CompletableFuture::join)),
collector("parallelToStream(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e(), p), Stream::toList)),
collector("parallelToStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e(), p), Stream::toList)),
collector("parallelToOrderedStream(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e(), p), Stream::toList)),
collector("parallelToOrderedStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p), Stream::toList))
limitedCollector("parallel(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallel(f, e(), p), c -> c.join().toList())),
limitedCollector("parallel(toList(), e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e(), p), CompletableFuture::join)),
limitedCollector("parallel(toList(), e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p), CompletableFuture::join)),
limitedCollector("parallelToStream(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallelToStream(f, e(), p), Stream::toList)),
limitedCollector("parallelToStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e(), p), Stream::toList)),
limitedCollector("parallelToOrderedStream(e, p)", (f, p) -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e(), p), Stream::toList)),
limitedCollector("parallelToOrderedStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p), Stream::toList))
);
}

Expand Down Expand Up @@ -76,19 +73,9 @@ Stream<DynamicTest> shouldRejectInvalidParallelism() {
})));
}

protected record CollectorDefinition<T, R>(String name, Factory.CollectorFactoryWithParallelism<T, R> factory) {
static <T, R> CollectorDefinition<T, R> collector(String name, Factory.CollectorFactoryWithParallelism<T, R> collector) {
return new CollectorDefinition<>(name, collector);
}
}

private static Executor e() {
return Executors.newCachedThreadPool();
}

private static Duration timed(Supplier<?> action) {
long start = System.currentTimeMillis();
var result = action.get();
var ignored = action.get();
return Duration.ofMillis(System.currentTimeMillis() - start);
}
}
Original file line number Diff line number Diff line change
@@ -1,75 +1,32 @@
package com.pivovarit.collectors.test;

import com.pivovarit.collectors.ParallelCollectors;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.pivovarit.collectors.TestUtils.returnWithDelay;
import static com.pivovarit.collectors.test.BasicProcessingTest.CollectorDefinition.collector;
import static com.pivovarit.collectors.test.Factory.all;
import static com.pivovarit.collectors.test.Factory.allOrdered;
import static java.time.Duration.ofSeconds;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
import static org.awaitility.Awaitility.await;

class BasicProcessingTest {

private static Stream<CollectorDefinition<Integer, Integer>> all() {
return Stream.of(
collector("parallel()", f -> collectingAndThen(ParallelCollectors.parallel(f), c -> c.join().toList())),
collector("parallel(e)", f -> collectingAndThen(ParallelCollectors.parallel(f, e()), c -> c.join().toList())),
collector("parallel(e, p)", f -> collectingAndThen(ParallelCollectors.parallel(f, e(), p()), c -> c.join().toList())),
collector("parallel(toList())", f -> collectingAndThen(ParallelCollectors.parallel(f, toList()), CompletableFuture::join)),
collector("parallel(toList(), e)", f -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e()), CompletableFuture::join)),
collector("parallel(toList(), e, p)", f -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e(), p()), CompletableFuture::join)),
collector("parallel(toList(), e, p) [batching]", f -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p()), CompletableFuture::join)),
collector("parallelToStream()", f -> collectingAndThen(ParallelCollectors.parallelToStream(f), Stream::toList)),
collector("parallelToStream(e)", f -> collectingAndThen(ParallelCollectors.parallelToStream(f, e()), Stream::toList)),
collector("parallelToStream(e, p)", f -> collectingAndThen(ParallelCollectors.parallelToStream(f, e(), p()), Stream::toList)),
collector("parallelToStream(e, p) [batching]", f -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e(), p()), Stream::toList)),
collector("parallelToOrderedStream()", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f), Stream::toList)),
collector("parallelToOrderedStream(e)", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e()), Stream::toList)),
collector("parallelToOrderedStream(e, p)", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e(), p()), Stream::toList)),
collector("parallelToOrderedStream(e, p) [batching]", f -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p()), Stream::toList))
);
}

public static Stream<CollectorDefinition<Integer, Integer>> allOrdered() {
return Stream.of(
collector("parallel()", f -> collectingAndThen(ParallelCollectors.parallel(f), c -> c.join().toList())),
collector("parallel(e)", f -> collectingAndThen(ParallelCollectors.parallel(f, e()), c -> c.join().toList())),
collector("parallel(e, p)", f -> collectingAndThen(ParallelCollectors.parallel(f, e(), p()), c -> c.join().toList())),
collector("parallel(toList())", f -> collectingAndThen(ParallelCollectors.parallel(f, toList()), CompletableFuture::join)),
collector("parallel(toList(), e)", f -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e()), CompletableFuture::join)),
collector("parallel(toList(), e, p)", f -> collectingAndThen(ParallelCollectors.parallel(f, toList(), e(), p()), CompletableFuture::join)),
collector("parallel(toList(), e, p) [batching]", f -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p()), CompletableFuture::join)),
collector("parallelToOrderedStream()", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f), Stream::toList)),
collector("parallelToOrderedStream(e)", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e()), Stream::toList)),
collector("parallelToOrderedStream(e, p)", f -> collectingAndThen(ParallelCollectors.parallelToOrderedStream(f, e(), p()), Stream::toList)),
collector("parallelToOrderedStream(e, p) [batching]", f -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p()), Stream::toList))
);
}

@TestFactory
Stream<DynamicTest> shouldProcessEmpty() {
return all()
.map(c -> DynamicTest.dynamicTest(c.name(), () -> {
assertThat(Stream.<Integer>empty().collect(c.collector().collector(i -> i))).isEmpty();
assertThat(Stream.<Integer>empty().collect(c.factory().collector(i -> i))).isEmpty();
}));
}

Expand All @@ -78,7 +35,7 @@ Stream<DynamicTest> shouldProcessAllElements() {
return all()
.map(c -> DynamicTest.dynamicTest(c.name(), () -> {
var list = IntStream.range(0, 100).boxed().toList();
List<Integer> result = list.stream().collect(c.collector().collector(i -> i));
List<Integer> result = list.stream().collect(c.factory().collector(i -> i));
assertThat(result).containsExactlyInAnyOrderElementsOf(list);
}));
}
Expand All @@ -88,7 +45,7 @@ Stream<DynamicTest> shouldProcessAllElementsInOrder() {
return allOrdered()
.map(c -> DynamicTest.dynamicTest(c.name(), () -> {
var list = IntStream.range(0, 100).boxed().toList();
List<Integer> result = list.stream().collect(c.collector().collector(i -> i));
List<Integer> result = list.stream().collect(c.factory().collector(i -> i));
assertThat(result).containsAnyElementsOf(list);
}));
}
Expand All @@ -102,7 +59,7 @@ Stream<DynamicTest> shouldStartProcessingImmediately() {
Thread.startVirtualThread(() -> {
Stream.iterate(0, i -> i + 1)
.limit(100)
.collect(c.collector().collector(i -> returnWithDelay(counter.incrementAndGet(), ofSeconds(1))));
.collect(c.factory().collector(i -> returnWithDelay(counter.incrementAndGet(), ofSeconds(1))));
});

await()
Expand All @@ -121,7 +78,7 @@ Stream<DynamicTest> shouldInterruptOnException() {
var latch = new CountDownLatch(size);

assertThatThrownBy(() -> IntStream.range(0, size).boxed()
.collect(c.collector().collector(i -> {
.collect(c.factory().collector(i -> {
try {
latch.countDown();
latch.await();
Expand All @@ -139,18 +96,4 @@ Stream<DynamicTest> shouldInterruptOnException() {
await().atMost(1, SECONDS).until(() -> counter.get() == size - 1);
}));
}

record CollectorDefinition<T, R>(String name, Factory.CollectorFactory<T, R> collector) {
static <T, R> CollectorDefinition<T, R> collector(String name, Factory.CollectorFactory<T, R> collector) {
return new CollectorDefinition<>(name, collector);
}
}

private static Executor e() {
return Executors.newCachedThreadPool();
}

private static int p() {
return 4;
}
}
30 changes: 8 additions & 22 deletions src/test/java/com/pivovarit/collectors/test/BatchingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,23 @@
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.TestFactory;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;

import static com.pivovarit.collectors.test.BatchingTest.CollectorDefinition.collector;
import static com.pivovarit.collectors.test.Factory.GenericCollector.limitedCollector;
import static com.pivovarit.collectors.test.Factory.e;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;

class BatchingTest {
private static Stream<BatchingTest.CollectorDefinition<Integer, Integer>> allBatching() {
private static Stream<Factory.GenericCollector<Factory.CollectorFactoryWithParallelism<Integer, Integer>>> allBatching() {
return Stream.of(
collector("parallel(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e(), p), c -> c.thenApply(Stream::toList).join())),
collector("parallel(toList(), e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p), CompletableFuture::join)),
collector("parallelToStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e(), p), Stream::toList)),
collector("parallelToOrderedStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p), Stream::toList))
limitedCollector("parallel(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, e(), p), c -> c.thenApply(Stream::toList).join())),
limitedCollector("parallel(toList(), e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallel(f, toList(), e(), p), CompletableFuture::join)),
limitedCollector("parallelToStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToStream(f, e(), p), Stream::toList)),
limitedCollector("parallelToOrderedStream(e, p) [batching]", (f, p) -> collectingAndThen(ParallelCollectors.Batching.parallelToOrderedStream(f, e(), p), Stream::toList))
);
}

Expand All @@ -37,22 +33,12 @@ Stream<DynamicTest> shouldProcessOnExactlyNThreads() {

Stream.generate(() -> 42)
.limit(100)
.collect(c.collector().collector(i -> {
.collect(c.factory().collector(i -> {
threads.add(Thread.currentThread().getName());
return i;
}, parallelism));

assertThat(threads).hasSizeLessThanOrEqualTo(parallelism);
}));
}

record CollectorDefinition<T, R>(String name, Factory.CollectorFactoryWithParallelism<T, R> collector) {
static <T, R> CollectorDefinition<T, R> collector(String name, Factory.CollectorFactoryWithParallelism<T, R> collector) {
return new CollectorDefinition<>(name, collector);
}
}

private static Executor e() {
return Executors.newCachedThreadPool();
}
}
Loading

0 comments on commit 0ac103c

Please sign in to comment.