diff --git a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java index ad7ec9063cc4..342c39039d3b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java @@ -31,6 +31,8 @@ import it.unimi.dsi.fastutil.objects.ObjectArrayList; import org.openjdk.jol.info.ClassLayout; +import javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -682,11 +684,16 @@ public Void getResult() } class AddLowCardinalityDictionaryPageWork - extends LowCardinalityDictionaryWork + implements Work { + private final Page page; + @Nullable + private int[] combinationIdToPosition; + private int nextCombinationId; + public AddLowCardinalityDictionaryPageWork(Page page) { - super(page); + this.page = requireNonNull(page, "page is null"); } @Override @@ -698,18 +705,20 @@ public boolean process() return false; } - int[] combinationIdToPosition = new int[maxCardinality]; - Arrays.fill(combinationIdToPosition, -1); - calculateCombinationIdsToPositionMapping(combinationIdToPosition); + if (combinationIdToPosition == null) { + combinationIdToPosition = calculateCombinationIdToPositionMapping(page); + } // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. - for (int i = 0; i < maxCardinality; i++) { - if (needRehash()) { - return false; - } - if (combinationIdToPosition[i] != -1) { - putIfAbsent(combinationIdToPosition[i], page); + for (int combinationId = nextCombinationId; combinationId < combinationIdToPosition.length; combinationId++) { + int position = combinationIdToPosition[combinationId]; + if (position != -1) { + if (needRehash()) { + nextCombinationId = combinationId; + return false; + } + putIfAbsent(position, page); } } return true; @@ -816,14 +825,20 @@ public GroupByIdBlock getResult() @VisibleForTesting class GetLowCardinalityDictionaryGroupIdsWork - extends LowCardinalityDictionaryWork + implements Work { + private final Page page; private final long[] groupIds; + @Nullable + private short[] positionToCombinationId; + @Nullable + private int[] combinationIdToGroupId; + private int nextPosition; private boolean finished; public GetLowCardinalityDictionaryGroupIdsWork(Page page) { - super(page); + this.page = requireNonNull(page, "page is null"); groupIds = new long[page.getPositionCount()]; } @@ -836,27 +851,27 @@ public boolean process() return false; } - int positionCount = page.getPositionCount(); - int[] combinationIdToPosition = new int[maxCardinality]; - Arrays.fill(combinationIdToPosition, -1); - short[] positionToCombinationId = calculateCombinationIdsToPositionMapping(combinationIdToPosition); - int[] combinationIdToGroupId = new int[maxCardinality]; + if (positionToCombinationId == null) { + positionToCombinationId = new short[groupIds.length]; + int maxCardinality = calculatePositionToCombinationIdMapping(page, positionToCombinationId); + combinationIdToGroupId = new int[maxCardinality]; + Arrays.fill(combinationIdToGroupId, -1); + } - // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. - // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. - for (int i = 0; i < maxCardinality; i++) { - if (needRehash()) { - return false; - } - if (combinationIdToPosition[i] != -1) { - combinationIdToGroupId[i] = putIfAbsent(combinationIdToPosition[i], page); - } - else { - combinationIdToGroupId[i] = -1; + for (int position = nextPosition; position < groupIds.length; position++) { + short combinationId = positionToCombinationId[position]; + int groupId = combinationIdToGroupId[combinationId]; + if (groupId == -1) { + // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. + // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. + if (needRehash()) { + nextPosition = position; + return false; + } + groupId = putIfAbsent(position, page); + combinationIdToGroupId[combinationId] = groupId; } - } - for (int i = 0; i < positionCount; i++) { - groupIds[i] = combinationIdToGroupId[positionToCombinationId[i]]; + groupIds[position] = groupId; } return true; } @@ -980,55 +995,53 @@ public GroupByIdBlock getResult() } } - private abstract class LowCardinalityDictionaryWork - implements Work + /** + * Returns an array containing a position that corresponds to the low cardinality + * dictionary combinationId, or a value of -1 if no position exists within the page + * for that combinationId. + */ + private int[] calculateCombinationIdToPositionMapping(Page page) { - protected final Page page; - protected final int maxCardinality; - protected final int[] dictionarySizes; - protected final DictionaryBlock[] blocks; + short[] positionToCombinationId = new short[page.getPositionCount()]; + int maxCardinality = calculatePositionToCombinationIdMapping(page, positionToCombinationId); - public LowCardinalityDictionaryWork(Page page) - { - this.page = requireNonNull(page, "page is null"); - dictionarySizes = new int[channels.length]; - blocks = new DictionaryBlock[channels.length]; - int maxCardinality = 1; - for (int i = 0; i < channels.length; i++) { - Block block = page.getBlock(channels[i]); - verify(block instanceof DictionaryBlock, "Only dictionary blocks are supported"); - blocks[i] = (DictionaryBlock) block; - int blockPositionCount = blocks[i].getDictionary().getPositionCount(); - dictionarySizes[i] = blockPositionCount; - maxCardinality *= blockPositionCount; - } - this.maxCardinality = maxCardinality; + int[] combinationIdToPosition = new int[maxCardinality]; + Arrays.fill(combinationIdToPosition, -1); + for (int position = 0; position < positionToCombinationId.length; position++) { + combinationIdToPosition[positionToCombinationId[position]] = position; } + return combinationIdToPosition; + } - /** - * Returns combinations of all dictionaries ids for every position and populates - * samplePositions array with a single occurrence of every used combination - */ - protected short[] calculateCombinationIdsToPositionMapping(int[] combinationIdToPosition) - { - int positionCount = page.getPositionCount(); - // short arrays improve performance compared to int - short[] combinationIds = new short[positionCount]; - - for (int i = 0; i < positionCount; i++) { - combinationIds[i] = (short) blocks[0].getId(i); - } - for (int j = 1; j < channels.length; j++) { - for (int i = 0; i < positionCount; i++) { - combinationIds[i] *= dictionarySizes[j]; - combinationIds[i] += blocks[j].getId(i); + /** + * Returns the number of combinations of all dictionary ids in input page blocks and populates + * positionToCombinationIds with the combinationId for each position in the input Page + */ + private int calculatePositionToCombinationIdMapping(Page page, short[] positionToCombinationIds) + { + checkArgument(positionToCombinationIds.length == page.getPositionCount()); + + int maxCardinality = 1; + for (int channel = 0; channel < channels.length; channel++) { + Block block = page.getBlock(channels[channel]); + verify(block instanceof DictionaryBlock, "Only dictionary blocks are supported"); + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + int dictionarySize = dictionaryBlock.getDictionary().getPositionCount(); + maxCardinality *= dictionarySize; + if (channel == 0) { + for (int position = 0; position < positionToCombinationIds.length; position++) { + positionToCombinationIds[position] = (short) dictionaryBlock.getId(position); } } - - for (int i = 0; i < positionCount; i++) { - combinationIdToPosition[combinationIds[i]] = i; + else { + for (int position = 0; position < positionToCombinationIds.length; position++) { + short combinationId = positionToCombinationIds[position]; + combinationId *= dictionarySize; + combinationId += dictionaryBlock.getId(position); + positionToCombinationIds[position] = combinationId; + } } - return combinationIds; } + return maxCardinality; } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java index 471b225d86e1..778c4ba74046 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.DictionaryId; +import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.type.Type; @@ -581,6 +582,43 @@ public void testLowCardinalityDictionariesGetGroupIds() assertThat(lowCardinalityResults.getGroupCount()).isEqualTo(results.getGroupCount()); } + @Test + public void testLowCardinalityDictionariesProperGroupIdOrder() + { + GroupByHash groupByHash = createGroupByHash( + TEST_SESSION, + ImmutableList.of(BIGINT, BIGINT), + new int[] {0, 1}, + Optional.empty(), + 100, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY, + NOOP); + + Block dictionary = new LongArrayBlock(2, Optional.empty(), new long[] {0, 1}); + int[] ids = new int[32]; + for (int i = 0; i < 16; i++) { + ids[i] = 1; + } + Block block1 = new DictionaryBlock(dictionary, ids); + Block block2 = new DictionaryBlock(dictionary, ids); + + Page page = new Page(block1, block2); + + Work work = groupByHash.getGroupIds(page); + assertThat(work).isInstanceOf(GetLowCardinalityDictionaryGroupIdsWork.class); + + work.process(); + GroupByIdBlock results = work.getResult(); + // Records with group id '0' should come before '1' despite being in the end of the block + for (int i = 0; i < 16; i++) { + assertThat(results.getGroupId(i)).isEqualTo(0); + } + for (int i = 16; i < 32; i++) { + assertThat(results.getGroupId(i)).isEqualTo(1); + } + } + @Test public void testProperWorkTypesSelected() {