Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
Expand Down Expand Up @@ -161,13 +162,23 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChann
public Work<?> addPage(Page page)
{
    currentPageSizeInBytes = page.getRetainedSizeInBytes();
    Block hashBlock = page.getBlock(hashChannel);
    // An RLE block carries a single distinct value, so the whole page can be
    // absorbed by inserting just that one value.
    if (hashBlock instanceof RunLengthEncodedBlock) {
        return new AddRunLengthEncodedPageWork((RunLengthEncodedBlock) hashBlock);
    }
    return new AddPageWork(hashBlock);
}

@Override
public Work<GroupByIdBlock> getGroupIds(Page page)
{
    currentPageSizeInBytes = page.getRetainedSizeInBytes();
    Block block = page.getBlock(hashChannel);
    // An RLE block has a single distinct value; produce group ids by hashing
    // only the first position and replicating the result.
    if (block instanceof RunLengthEncodedBlock) {
        return new GetRunLengthEncodedGroupIdsWork((RunLengthEncodedBlock) block);
    }
    // Reuse the already-fetched block instead of calling page.getBlock(hashChannel)
    // a second time (keeps this method consistent with addPage above).
    return new GetGroupIdsWork(block);
}

Expand Down Expand Up @@ -374,6 +385,47 @@ public Void getResult()
}
}

private class AddRunLengthEncodedPageWork
Copy link
Copy Markdown
Member Author

@sopel39 sopel39 Oct 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is based on io.trino.operator.MultiChannelGroupByHash.AddRunLengthEncodedPageWork code

implements Work<Void>
{
private final RunLengthEncodedBlock block;

private boolean finished;

public AddRunLengthEncodedPageWork(RunLengthEncodedBlock block)
{
this.block = requireNonNull(block, "block is null");
}

@Override
public boolean process()
{
checkState(!finished);
if (block.getPositionCount() == 0) {
finished = true;
return true;
}

// needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
// We can only proceed if tryRehash() successfully did a rehash.
if (needRehash() && !tryRehash()) {
return false;
}

// Only needs to process the first row since it is Run Length Encoded
putIfAbsent(0, block.getValue());
finished = true;

return true;
}

@Override
public Void getResult()
{
throw new UnsupportedOperationException();
}
}

private class GetGroupIdsWork
implements Work<GroupByIdBlock>
{
Expand Down Expand Up @@ -422,4 +474,54 @@ public GroupByIdBlock getResult()
return new GroupByIdBlock(nextGroupId, blockBuilder.build());
}
}

/**
 * Computes group ids for a run-length-encoded block. Only the first position
 * is hashed; the resulting single group id is replicated as an RLE block of
 * the input's position count, so downstream operators keep the RLE encoding.
 */
private class GetRunLengthEncodedGroupIdsWork
        implements Work<GroupByIdBlock>
{
    private final RunLengthEncodedBlock block;

    // Group id assigned to the block's single distinct value; -1 until process() runs.
    // Made private for consistency with the other fields in this file.
    private int groupId = -1;
    private boolean processFinished;
    private boolean resultProduced;

    public GetRunLengthEncodedGroupIdsWork(RunLengthEncodedBlock block)
    {
        this.block = requireNonNull(block, "block is null");
    }

    @Override
    public boolean process()
    {
        checkState(!processFinished);
        if (block.getPositionCount() == 0) {
            processFinished = true;
            return true;
        }

        // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
        // We can only proceed if tryRehash() successfully did a rehash.
        if (needRehash() && !tryRehash()) {
            return false;
        }

        // Only needs to process the first row since it is Run Length Encoded
        groupId = putIfAbsent(0, block.getValue());
        processFinished = true;
        return true;
    }

    @Override
    public GroupByIdBlock getResult()
    {
        checkState(processFinished);
        checkState(!resultProduced);
        resultProduced = true;

        // Replicate the single group id across all input positions as RLE.
        return new GroupByIdBlock(
                nextGroupId,
                new RunLengthEncodedBlock(
                        BIGINT.createFixedSizeBlockBuilder(1).writeLong(groupId).build(),
                        block.getPositionCount()));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
Expand Down Expand Up @@ -50,6 +52,7 @@
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import static io.trino.jmh.Benchmarks.benchmark;
import static io.trino.operator.UpdateMemory.NOOP;
Expand Down Expand Up @@ -217,11 +220,12 @@ private static void addInputPagesToHash(GroupByHash groupByHash, List<Page> page
boolean finished;
do {
finished = work.process();
} while (!finished);
}
while (!finished);
}
}

private static List<Page> createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled)
private static List<Page> createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled, boolean pollute)
Comment thread
sopel39 marked this conversation as resolved.
{
List<Type> types = Collections.nCopies(channelCount, BIGINT);
ImmutableList.Builder<Page> pages = ImmutableList.builder();
Expand All @@ -230,6 +234,7 @@ private static List<Page> createBigintPages(int positionCount, int groupCount, i
}

PageBuilder pageBuilder = new PageBuilder(types);
int pageCount = 0;
for (int position = 0; position < positionCount; position++) {
int rand = ThreadLocalRandom.current().nextInt(groupCount);
pageBuilder.declarePosition();
Expand All @@ -240,8 +245,34 @@ private static List<Page> createBigintPages(int positionCount, int groupCount, i
BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), AbstractLongType.hash(rand));
}
if (pageBuilder.isFull()) {
pages.add(pageBuilder.build());
Page page = pageBuilder.build();
pageBuilder.reset();
if (pollute) {
if (pageCount % 3 == 0) {
pages.add(page);
}
else if (pageCount % 3 == 1) {
// rle page
Block[] blocks = new Block[page.getChannelCount()];
for (int channel = 0; channel < blocks.length; ++channel) {
blocks[channel] = new RunLengthEncodedBlock(page.getBlock(channel).getSingleValueBlock(0), page.getPositionCount());
}
pages.add(new Page(blocks));
}
else {
// dictionary page
int[] positions = IntStream.range(0, page.getPositionCount()).toArray();
Block[] blocks = new Block[page.getChannelCount()];
for (int channel = 0; channel < page.getChannelCount(); ++channel) {
blocks[channel] = new DictionaryBlock(page.getBlock(channel), positions);
}
pages.add(new Page(blocks));
}
}
else {
pages.add(page);
}
pageCount++;
}
}
pages.add(pageBuilder.build());
Expand Down Expand Up @@ -294,7 +325,7 @@ public static class BaselinePagesData
@Setup
public void setup()
{
pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled);
pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false);
}

public List<Page> getPages()
Expand All @@ -320,7 +351,12 @@ public static class SingleChannelBenchmarkData
@Setup
public void setup()
{
pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled);
setup(false);
}

public void setup(boolean pollute)
{
pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled, pollute);
types = Collections.nCopies(1, BIGINT);
channels = new int[1];
for (int i = 0; i < 1; i++) {
Expand Down Expand Up @@ -376,7 +412,7 @@ public void setup()
break;
case "BIGINT":
types = Collections.nCopies(channelCount, BIGINT);
pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled);
pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false);
break;
default:
throw new UnsupportedOperationException("Unsupported dataType");
Expand Down Expand Up @@ -414,6 +450,16 @@ private static JoinCompiler getJoinCompiler()
return new JoinCompiler(TYPE_OPERATORS);
}

static {
    // Pollute the BigintGroupByHash JIT profile with different block types
    // (RLE and dictionary pages) before the benchmark runs, so measurements
    // reflect a megamorphic call-site profile rather than a monomorphic one.
    SingleChannelBenchmarkData singleChannelBenchmarkData = new SingleChannelBenchmarkData();
    singleChannelBenchmarkData.setup(true);
    BenchmarkGroupByHash hash = new BenchmarkGroupByHash();
    for (int i = 0; i < 5; ++i) {
        hash.bigintGroupByHash(singleChannelBenchmarkData);
    }
}

public static void main(String[] args)
throws RunnerException
{
Expand All @@ -431,6 +477,6 @@ public static void main(String[] args)
.withOptions(optionsBuilder -> optionsBuilder
.addProfiler(GCProfiler.class)
.jvmArgs("-Xmx10g"))
.run();
.run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.DictionaryId;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.JoinCompiler;
Expand Down Expand Up @@ -95,6 +96,34 @@ public void testAddPage()
}
}

@Test
public void testRunLengthEncodedBigintGroupByHash()
{
    GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY);
    // Build a 2-position page where both the value and hash channels are RLE.
    Block valueBlock = BlockAssertions.createLongsBlock(0L);
    Block rawHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valueBlock);
    Page rlePage = new Page(
            new RunLengthEncodedBlock(valueBlock, 2),
            new RunLengthEncodedBlock(rawHashBlock, 2));

    groupByHash.addPage(rlePage).process();

    assertEquals(groupByHash.getGroupCount(), 1);

    Work<GroupByIdBlock> groupIdsWork = groupByHash.getGroupIds(rlePage);
    groupIdsWork.process();
    GroupByIdBlock result = groupIdsWork.getResult();

    // Both positions collapse to the single group id 0.
    assertEquals(result.getGroupCount(), 1);
    assertEquals(result.getPositionCount(), 2);
    assertEquals(result.getGroupId(0), 0);
    assertEquals(result.getGroupId(1), 0);

    // The RLE encoding must be preserved in the output block.
    List<Block> outputBlocks = result.getChildren();
    assertEquals(outputBlocks.size(), 1);
    assertTrue(outputBlocks.get(0) instanceof RunLengthEncodedBlock);
}

@Test
public void testNullGroup()
{
Expand Down