diff --git a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java index 2e4bfe26bce4..3775112cf3fc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java @@ -22,6 +22,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.AbstractLongType; import io.trino.spi.type.BigintType; import io.trino.spi.type.Type; @@ -161,13 +162,23 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChann public Work<?> addPage(Page page) { currentPageSizeInBytes = page.getRetainedSizeInBytes(); - return new AddPageWork(page.getBlock(hashChannel)); + Block block = page.getBlock(hashChannel); + if (block instanceof RunLengthEncodedBlock) { + return new AddRunLengthEncodedPageWork((RunLengthEncodedBlock) block); + } + + return new AddPageWork(block); } @Override public Work<GroupByIdBlock> getGroupIds(Page page) { currentPageSizeInBytes = page.getRetainedSizeInBytes(); + Block block = page.getBlock(hashChannel); + if (block instanceof RunLengthEncodedBlock) { + return new GetRunLengthEncodedGroupIdsWork((RunLengthEncodedBlock) block); + } + return new GetGroupIdsWork(page.getBlock(hashChannel)); } @@ -374,6 +385,47 @@ public Void getResult() } } + private class AddRunLengthEncodedPageWork + implements Work<Void> + { + private final RunLengthEncodedBlock block; + + private boolean finished; + + public AddRunLengthEncodedPageWork(RunLengthEncodedBlock block) + { + this.block = requireNonNull(block, "block is null"); + } + + @Override + public boolean process() + { + checkState(!finished); + if (block.getPositionCount() == 0) { + finished = true; + return true; + } + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
+ // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + // Only needs to process the first row since it is Run Length Encoded + putIfAbsent(0, block.getValue()); + finished = true; + + return true; + } + + @Override + public Void getResult() + { + throw new UnsupportedOperationException(); + } + } + private class GetGroupIdsWork implements Work<GroupByIdBlock> { @@ -422,4 +474,54 @@ public GroupByIdBlock getResult() return new GroupByIdBlock(nextGroupId, blockBuilder.build()); } } + + private class GetRunLengthEncodedGroupIdsWork + implements Work<GroupByIdBlock> + { + private final RunLengthEncodedBlock block; + + int groupId = -1; + private boolean processFinished; + private boolean resultProduced; + + public GetRunLengthEncodedGroupIdsWork(RunLengthEncodedBlock block) + { + this.block = requireNonNull(block, "block is null"); + } + + @Override + public boolean process() + { + checkState(!processFinished); + if (block.getPositionCount() == 0) { + processFinished = true; + return true; + } + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash.
+ if (needRehash() && !tryRehash()) { + return false; + } + + // Only needs to process the first row since it is Run Length Encoded + groupId = putIfAbsent(0, block.getValue()); + processFinished = true; + return true; + } + + @Override + public GroupByIdBlock getResult() + { + checkState(processFinished); + checkState(!resultProduced); + resultProduced = true; + + return new GroupByIdBlock( + nextGroupId, + new RunLengthEncodedBlock( + BIGINT.createFixedSizeBlockBuilder(1).writeLong(groupId).build(), + block.getPositionCount())); + } + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHash.java index 417a513cd5ef..33ce704d15e3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHash.java +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHash.java @@ -22,7 +22,9 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; +import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.AbstractLongType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -50,6 +52,7 @@ import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; import static io.trino.jmh.Benchmarks.benchmark; import static io.trino.operator.UpdateMemory.NOOP; @@ -217,11 +220,12 @@ private static void addInputPagesToHash(GroupByHash groupByHash, List page boolean finished; do { finished = work.process(); - } while (!finished); + } + while (!finished); } } - private static List createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled) + private static List createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled, boolean pollute) { List types = 
Collections.nCopies(channelCount, BIGINT); ImmutableList.Builder pages = ImmutableList.builder(); @@ -230,6 +234,7 @@ private static List createBigintPages(int positionCount, int groupCount, i } PageBuilder pageBuilder = new PageBuilder(types); + int pageCount = 0; for (int position = 0; position < positionCount; position++) { int rand = ThreadLocalRandom.current().nextInt(groupCount); pageBuilder.declarePosition(); @@ -240,8 +245,34 @@ private static List createBigintPages(int positionCount, int groupCount, i BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), AbstractLongType.hash(rand)); } if (pageBuilder.isFull()) { - pages.add(pageBuilder.build()); + Page page = pageBuilder.build(); pageBuilder.reset(); + if (pollute) { + if (pageCount % 3 == 0) { + pages.add(page); + } + else if (pageCount % 3 == 1) { + // rle page + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < blocks.length; ++channel) { + blocks[channel] = new RunLengthEncodedBlock(page.getBlock(channel).getSingleValueBlock(0), page.getPositionCount()); + } + pages.add(new Page(blocks)); + } + else { + // dictionary page + int[] positions = IntStream.range(0, page.getPositionCount()).toArray(); + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < page.getChannelCount(); ++channel) { + blocks[channel] = new DictionaryBlock(page.getBlock(channel), positions); + } + pages.add(new Page(blocks)); + } + } + else { + pages.add(page); + } + pageCount++; } } pages.add(pageBuilder.build()); @@ -294,7 +325,7 @@ public static class BaselinePagesData @Setup public void setup() { - pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled); + pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false); } public List getPages() @@ -320,7 +351,12 @@ public static class SingleChannelBenchmarkData @Setup public void setup() { - pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, 
hashEnabled); + setup(false); + } + + public void setup(boolean pollute) + { + pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled, pollute); types = Collections.nCopies(1, BIGINT); channels = new int[1]; for (int i = 0; i < 1; i++) { @@ -376,7 +412,7 @@ public void setup() break; case "BIGINT": types = Collections.nCopies(channelCount, BIGINT); - pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled); + pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false); break; default: throw new UnsupportedOperationException("Unsupported dataType"); @@ -414,6 +450,16 @@ private static JoinCompiler getJoinCompiler() return new JoinCompiler(TYPE_OPERATORS); } + static { + // pollute BigintGroupByHash profile by different block types + SingleChannelBenchmarkData singleChannelBenchmarkData = new SingleChannelBenchmarkData(); + singleChannelBenchmarkData.setup(true); + BenchmarkGroupByHash hash = new BenchmarkGroupByHash(); + for (int i = 0; i < 5; ++i) { + hash.bigintGroupByHash(singleChannelBenchmarkData); + } + } + public static void main(String[] args) throws RunnerException { @@ -431,6 +477,6 @@ public static void main(String[] args) .withOptions(optionsBuilder -> optionsBuilder .addProfiler(GCProfiler.class) .jvmArgs("-Xmx10g")) - .run(); + .run(); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java index 21aaad8b2644..cd37f2e41756 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java @@ -21,6 +21,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.DictionaryId; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.JoinCompiler; @@ -95,6 +96,34 @@ public 
void testAddPage() } } + @Test + public void testRunLengthEncodedBigintGroupByHash() + { + GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY); + Block block = BlockAssertions.createLongsBlock(0L); + Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), block); + Page page = new Page( + new RunLengthEncodedBlock(block, 2), + new RunLengthEncodedBlock(hashBlock, 2)); + + groupByHash.addPage(page).process(); + + assertEquals(groupByHash.getGroupCount(), 1); + + Work work = groupByHash.getGroupIds(page); + work.process(); + GroupByIdBlock groupIds = work.getResult(); + + assertEquals(groupIds.getGroupCount(), 1); + assertEquals(groupIds.getPositionCount(), 2); + assertEquals(groupIds.getGroupId(0), 0); + assertEquals(groupIds.getGroupId(1), 0); + + List children = groupIds.getChildren(); + assertEquals(children.size(), 1); + assertTrue(children.get(0) instanceof RunLengthEncodedBlock); + } + @Test public void testNullGroup() {