Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import org.openjdk.jol.info.ClassLayout;

import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -68,6 +70,7 @@ public class BigintGroupByHash
private final LongBigArray valuesByGroupId;

private int nextGroupId;
private DictionaryLookBack dictionaryLookBack;
private long hashCollisions;
private double expectedHashCollisions;

Expand Down Expand Up @@ -166,6 +169,9 @@ public Work<?> addPage(Page page)
if (block instanceof RunLengthEncodedBlock) {
return new AddRunLengthEncodedPageWork((RunLengthEncodedBlock) block);
}
if (block instanceof DictionaryBlock) {
return new AddDictionaryPageWork((DictionaryBlock) block);
}

return new AddPageWork(block);
}
Expand All @@ -178,6 +184,9 @@ public Work<GroupByIdBlock> getGroupIds(Page page)
if (block instanceof RunLengthEncodedBlock) {
return new GetRunLengthEncodedGroupIdsWork((RunLengthEncodedBlock) block);
}
if (block instanceof DictionaryBlock) {
return new GetDictionaryGroupIdsWork((DictionaryBlock) block);
}

return new GetGroupIdsWork(page.getBlock(hashChannel));
}
Expand Down Expand Up @@ -344,6 +353,24 @@ private static int calculateMaxFill(int hashSize)
return maxFill;
}

private void updateDictionaryLookBack(Block dictionary)
{
    // Reuse the cached look-back only while it still refers to the exact same
    // dictionary instance (identity compare); otherwise start a fresh cache.
    boolean reusable = dictionaryLookBack != null && dictionaryLookBack.getDictionary() == dictionary;
    if (!reusable) {
        dictionaryLookBack = new DictionaryLookBack(dictionary);
    }
}

private int getGroupId(Block dictionary, int positionInDictionary)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you name it in a way that it does suggest it has side effects registerGroupId?

{
if (dictionaryLookBack.isProcessed(positionInDictionary)) {
return dictionaryLookBack.getGroupId(positionInDictionary);
}

int groupId = putIfAbsent(positionInDictionary, dictionary);
dictionaryLookBack.setProcessed(positionInDictionary, groupId);
return groupId;
}

private class AddPageWork
implements Work<Void>
{
Expand Down Expand Up @@ -385,6 +412,50 @@ public Void getResult()
}
}

/**
 * Incremental work that registers every position of a dictionary-encoded block
 * in the hash table, resolving each distinct dictionary entry only once.
 */
private class AddDictionaryPageWork
        implements Work<Void>
{
    private final Block dictionary;
    private final DictionaryBlock block;

    // Next position to consume; lets process() resume after yielding for a failed rehash.
    private int lastPosition;

    public AddDictionaryPageWork(DictionaryBlock block)
    {
        this.block = requireNonNull(block, "block is null");
        this.dictionary = block.getDictionary();
        updateDictionaryLookBack(dictionary);
    }

    @Override
    public boolean process()
    {
        int positionCount = block.getPositionCount();
        checkState(lastPosition < positionCount, "position count out of bound");

        // A pending rehash must complete before any insert; if memory is not
        // available for it, yield and let the caller retry later.
        if (needRehash() && !tryRehash()) {
            return false;
        }

        // putIfAbsent rehashes on demand unless memory is short, so needRehash()
        // only becomes true here when that in-flight rehash could not be done.
        for (; lastPosition < positionCount && !needRehash(); lastPosition++) {
            getGroupId(dictionary, block.getId(lastPosition));
        }
        return lastPosition == positionCount;
    }

    @Override
    public Void getResult()
    {
        // Pure side-effect work: there is no result to hand back.
        throw new UnsupportedOperationException();
    }
}

private class AddRunLengthEncodedPageWork
implements Work<Void>
{
Expand Down Expand Up @@ -475,6 +546,60 @@ public GroupByIdBlock getResult()
}
}

/**
 * Incremental work that maps every position of a dictionary-encoded block to its
 * group id, emitting one BIGINT group id per input position. Distinct dictionary
 * entries are resolved against the hash table only once via the look-back cache.
 */
private class GetDictionaryGroupIdsWork
        implements Work<GroupByIdBlock>
{
    private final BlockBuilder blockBuilder;
    private final Block dictionary;
    private final DictionaryBlock block;

    private boolean finished;
    // Next position to consume; lets process() resume after yielding for a failed rehash.
    private int lastPosition;

    public GetDictionaryGroupIdsWork(DictionaryBlock block)
    {
        this.block = requireNonNull(block, "block is null");
        this.dictionary = block.getDictionary();
        updateDictionaryLookBack(dictionary);

        // we know the exact size required for the block
        this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(block.getPositionCount());
    }

    @Override
    public boolean process()
    {
        int positionCount = block.getPositionCount();
        checkState(lastPosition < positionCount, "position count out of bound");
        checkState(!finished);

        // A pending rehash must complete before any insert; if memory is not
        // available for it, yield and let the caller retry later.
        if (needRehash() && !tryRehash()) {
            return false;
        }

        // putIfAbsent rehashes on demand unless memory is short, so needRehash()
        // only becomes true here when that in-flight rehash could not be done.
        for (; lastPosition < positionCount && !needRehash(); lastPosition++) {
            BIGINT.writeLong(blockBuilder, getGroupId(dictionary, block.getId(lastPosition)));
        }
        return lastPosition == positionCount;
    }

    @Override
    public GroupByIdBlock getResult()
    {
        checkState(lastPosition == block.getPositionCount(), "process has not yet finished");
        checkState(!finished, "result has produced");
        finished = true;
        return new GroupByIdBlock(nextGroupId, blockBuilder.build());
    }
}

private class GetRunLengthEncodedGroupIdsWork
implements Work<GroupByIdBlock>
{
Expand Down Expand Up @@ -524,4 +649,37 @@ public GroupByIdBlock getResult()
block.getPositionCount()));
}
}

/**
 * Per-dictionary cache mapping dictionary positions to already-assigned group ids,
 * so each distinct dictionary entry is probed in the hash table at most once.
 */
private static final class DictionaryLookBack
{
    // Sentinel marking a dictionary position whose group id has not been resolved yet.
    private static final int UNPROCESSED = -1;

    private final Block dictionary;
    private final int[] groupIds;

    public DictionaryLookBack(Block dictionary)
    {
        this.dictionary = dictionary;
        this.groupIds = new int[dictionary.getPositionCount()];
        Arrays.fill(groupIds, UNPROCESSED);
    }

    public Block getDictionary()
    {
        return dictionary;
    }

    public int getGroupId(int position)
    {
        return groupIds[position];
    }

    public boolean isProcessed(int position)
    {
        return groupIds[position] != UNPROCESSED;
    }

    public void setProcessed(int position, int groupId)
    {
        groupIds[position] = groupId;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,33 @@ public void testRunLengthEncodedBigintGroupByHash()
assertTrue(children.get(0) instanceof RunLengthEncodedBlock);
}

@Test
public void testDictionaryBigintGroupByHash()
{
    GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY);

    // Two distinct values, each referenced twice through the dictionary ids.
    Block valueBlock = BlockAssertions.createLongsBlock(0L, 1L);
    Block rawHashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), valueBlock);
    int[] dictionaryIds = new int[] {0, 0, 1, 1};
    Page page = new Page(
            new DictionaryBlock(valueBlock, dictionaryIds),
            new DictionaryBlock(rawHashBlock, dictionaryIds));

    groupByHash.addPage(page).process();

    // Only two distinct underlying values -> two groups.
    assertEquals(groupByHash.getGroupCount(), 2);

    Work<GroupByIdBlock> work = groupByHash.getGroupIds(page);
    work.process();
    GroupByIdBlock groupIds = work.getResult();

    assertEquals(groupIds.getGroupCount(), 2);
    assertEquals(groupIds.getPositionCount(), 4);
    // Group ids follow the dictionary ids: positions 0,1 -> group 0; positions 2,3 -> group 1.
    assertEquals(groupIds.getGroupId(0), 0);
    assertEquals(groupIds.getGroupId(1), 0);
    assertEquals(groupIds.getGroupId(2), 1);
    assertEquals(groupIds.getGroupId(3), 1);
}

@Test
public void testNullGroup()
{
Expand Down