From 7724dd456098c019ff771d592adaa9b711722b29 Mon Sep 17 00:00:00 2001 From: Grzegorz Piwowarek Date: Sun, 6 Dec 2020 15:10:51 +0100 Subject: [PATCH] Don't batch when batch size is equal to one --- .../collectors/AsyncParallelCollector.java | 21 +++++++++---- .../collectors/ParallelStreamCollector.java | 30 +++++++++++++++---- .../pivovarit/collectors/FunctionalTest.java | 14 +++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index 64dc3178..f9361848 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -168,11 +168,22 @@ private BatchingCollectors() { private static Collector> batchingCollector(Function mapper, Executor executor, int parallelism, Function, RR> finisher) { return collectingAndThen( toList(), - list -> partitioned(list, parallelism) - .collect(new AsyncParallelCollector<>( - batching(mapper), - Dispatcher.of(executor, parallelism), - listStream -> finisher.apply(listStream.flatMap(Collection::stream))))); + list -> { + // no sense to repack into batches of size 1 + if (list.size() == parallelism) { + return list.stream() + .collect(new AsyncParallelCollector<>( + mapper, + Dispatcher.of(executor, parallelism), + finisher)); + } else { + return partitioned(list, parallelism) + .collect(new AsyncParallelCollector<>( + batching(mapper), + Dispatcher.of(executor, parallelism), + listStream -> finisher.apply(listStream.flatMap(Collection::stream)))); + } + }); } } } diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 52f8e051..dce357dd 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -2,7 +2,6 @@ import java.util.Collection; import java.util.EnumSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -130,7 +129,7 @@ private BatchingCollectors() { return parallelism == 1 ? syncCollector(mapper) - : batched(new ParallelStreamCollector<>(batching(mapper), unordered(), UNORDERED, executor, parallelism), parallelism); + : batchingCollector(mapper, executor, parallelism); } static Collector> streamingOrdered(Function mapper, Executor executor, int parallelism) { @@ -140,14 +139,33 @@ private BatchingCollectors() { return parallelism == 1 ? syncCollector(mapper) - : batched(new ParallelStreamCollector<>(batching(mapper), ordered(), emptySet(), executor, parallelism), parallelism); + : batchingCollector(mapper, executor, parallelism); } - private static Collector> batched(ParallelStreamCollector, List> downstream, int parallelism) { + private static Collector> batchingCollector(Function mapper, Executor executor, int parallelism) { return collectingAndThen( toList(), - list -> partitioned(list, parallelism) - .collect(collectingAndThen(downstream, s -> s.flatMap(Collection::stream)))); + list -> { + // no sense to repack into batches of size 1 + if (list.size() == parallelism) { + return list.stream() + .collect(new ParallelStreamCollector<>( + mapper, + ordered(), + emptySet(), + executor, + parallelism)); + } else { + return partitioned(list, parallelism) + .collect(collectingAndThen(new ParallelStreamCollector<>( + batching(mapper), + ordered(), + emptySet(), + executor, + parallelism), + s -> s.flatMap(Collection::stream))); + } + }); } private static Collector, Stream> syncCollector(Function mapper) { diff --git a/src/test/java/com/pivovarit/collectors/FunctionalTest.java b/src/test/java/com/pivovarit/collectors/FunctionalTest.java index 335a2d7a..fd7943cb 100644 --- a/src/test/java/com/pivovarit/collectors/FunctionalTest.java +++ b/src/test/java/com/pivovarit/collectors/FunctionalTest.java @@ -137,6 +137,8 @@ private static > Stream tests(Collect Stream tests = of( shouldCollect(collector, name, 1), shouldCollect(collector, name, PARALLELISM), + shouldCollectNElementsWithNParallelism(collector, name, 1), + shouldCollectNElementsWithNParallelism(collector, name, PARALLELISM), shouldCollectToEmpty(collector, name), shouldStartConsumingImmediately(collector, name), shouldTerminateAfterConsumingAllElements(collector, name), @@ -251,6 +253,18 @@ private static > DynamicTest shouldCollect(Collect }); } + private static > DynamicTest shouldCollectNElementsWithNParallelism(CollectorSupplier, Executor, Integer, Collector>> factory, String name, int parallelism) { + return dynamicTest(format("%s: should collect %s elements with parallelism %s", name, parallelism, parallelism), () -> { + + List elements = IntStream.iterate(0, i -> i + 1).limit(parallelism).boxed().collect(toList()); + Collector> ctor = factory.apply(i -> i, executor, parallelism); + Collection result = elements.stream().collect(ctor) + .join(); + + assertThat(result).hasSameElementsAs(elements); + }); + } + private static > DynamicTest shouldTerminateAfterConsumingAllElements(CollectorSupplier, Executor, Integer, Collector>> factory, String name) { return dynamicTest(format("%s: should terminate after consuming all elements", name), () -> { List elements = IntStream.range(0, 10).boxed().collect(toList());