diff --git a/core/trino-main/src/main/java/io/trino/operator/AppendOnlyVariableWidthData.java b/core/trino-main/src/main/java/io/trino/operator/AppendOnlyVariableWidthData.java
index 5b785f7f5faf..d595ec764939 100644
--- a/core/trino-main/src/main/java/io/trino/operator/AppendOnlyVariableWidthData.java
+++ b/core/trino-main/src/main/java/io/trino/operator/AppendOnlyVariableWidthData.java
@@ -138,6 +138,27 @@ public byte[] getChunk(byte[] pointer, int pointerOffset)
         return chunks.get(chunkIndex);
     }
 
+    public void freeChunksBefore(byte[] pointer, int pointerOffset)
+    {
+        int chunkIndex = getChunkIndex(pointer, pointerOffset);
+        if (chunks.isEmpty()) {
+            verify(chunkIndex == 0);
+            return;
+        }
+        checkIndex(chunkIndex, chunks.size());
+        // Release any previous chunks until a null chunk is encountered, which means it and any previous
+        // batches have already been released
+        int releaseIndex = chunkIndex - 1;
+        while (releaseIndex >= 0) {
+            byte[] releaseChunk = chunks.set(releaseIndex, null);
+            if (releaseChunk == null) {
+                break;
+            }
+            chunksRetainedSizeInBytes -= sizeOf(releaseChunk);
+            releaseIndex--;
+        }
+    }
+
     // growth factor for each chunk doubles up to 512KB, then increases by 1.5x for each chunk after that
     private static long nextChunkSize(long previousChunkSize)
     {
diff --git a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
index 8a81bd23489e..875504273a2f 100644
--- a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
@@ -119,6 +119,13 @@ public int getGroupCount()
         return nextGroupId;
     }
 
+    @Override
+    public void startReleasingOutput()
+    {
+        dictionaryLookBack = null;
+        currentPageSizeInBytes = 0;
+    }
+
     @Override
     public void appendValuesTo(int groupId, PageBuilder pageBuilder)
     {
diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java
index 746499d2c053..b1cf4246bdbc 100644
--- a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java
@@ -129,6 +129,16 @@ public int getGroupCount()
         return flatHash.size();
     }
 
+    @Override
+    public void startReleasingOutput()
+    {
+        currentHashes = null;
+        dictionaryLookBack = null;
+        Arrays.fill(currentBlocks, null);
+        currentPageSizeInBytes = 0;
+        flatHash.startReleasingOutput();
+    }
+
     @Override
     public void appendValuesTo(int groupId, PageBuilder pageBuilder)
     {
@@ -383,11 +393,10 @@ public AddDictionaryPageWork(Block[] blocks)
         {
             verify(canProcessDictionary(blocks), "invalid call to addDictionaryPage");
             this.dictionaryBlock = (DictionaryBlock) blocks[0];
-
-            this.dictionaries = Arrays.stream(blocks)
-                    .map(block -> (DictionaryBlock) block)
-                    .map(DictionaryBlock::getDictionary)
-                    .toArray(Block[]::new);
+            this.dictionaries = blocks;
+            for (int i = 0; i < dictionaries.length; i++) {
+                dictionaries[i] = ((DictionaryBlock) dictionaries[i]).getDictionary();
+            }
             updateDictionaryLookBack(dictionaries[0]);
         }
 
@@ -500,7 +509,7 @@ class GetNonDictionaryGroupIdsWork
         public GetNonDictionaryGroupIdsWork(Block[] blocks)
         {
             this.blocks = blocks;
-            this.groupIds = new int[currentBlocks[0].getPositionCount()];
+            this.groupIds = new int[blocks[0].getPositionCount()];
         }
 
         @Override
@@ -610,13 +619,12 @@ public GetDictionaryGroupIdsWork(Block[] blocks)
         {
             verify(canProcessDictionary(blocks), "invalid call to processDictionary");
             this.dictionaryBlock = (DictionaryBlock) blocks[0];
-            this.groupIds = new int[dictionaryBlock.getPositionCount()];
-
-            this.dictionaries = Arrays.stream(blocks)
-                    .map(block -> (DictionaryBlock) block)
-                    .map(DictionaryBlock::getDictionary)
-                    .toArray(Block[]::new);
+            this.dictionaries = blocks;
+            for (int i = 0; i < dictionaries.length; i++) {
+                dictionaries[i] = ((DictionaryBlock) dictionaries[i]).getDictionary();
+            }
             updateDictionaryLookBack(dictionaries[0]);
+            this.groupIds = new int[dictionaryBlock.getPositionCount()];
         }
 
         @Override
diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java
index c18414ee04fa..d3136d0ae6e6 100644
--- a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java
@@ -22,6 +22,7 @@
 import java.util.Arrays;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Throwables.throwIfUnchecked;
 import static io.airlift.slice.SizeOf.instanceSize;
 import static io.airlift.slice.SizeOf.sizeOf;
@@ -120,8 +121,8 @@ public FlatHash(FlatHash other)
         this.mask = other.mask;
         this.nextGroupId = other.nextGroupId;
         this.maxFill = other.maxFill;
-        this.control = Arrays.copyOf(other.control, other.control.length);
-        this.groupIdsByHash = Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
+        this.control = other.control == null ? null : Arrays.copyOf(other.control, other.control.length);
+        this.groupIdsByHash = other.groupIdsByHash == null ? null : Arrays.copyOf(other.groupIdsByHash, other.groupIdsByHash.length);
         this.fixedSizeRecords = Arrays.stream(other.fixedSizeRecords)
                 .map(fixedSizeRecords -> fixedSizeRecords == null ? null : Arrays.copyOf(fixedSizeRecords, fixedSizeRecords.length))
                 .toArray(byte[][]::new);
@@ -149,13 +150,32 @@ public int getCapacity()
         return capacity;
     }
 
+    /**
+     * Releases memory associated with the hash table which is no longer necessary to produce output. Subsequent
+     * calls to insert new elements are rejected, and calls to {@link FlatHash#appendTo(int, BlockBuilder[])} will
+     * incrementally release memory associated with prior groupId values assuming that the caller will only call into
+     * the method to produce output in a sequential fashion.
+     */
+    public void startReleasingOutput()
+    {
+        checkState(!isReleasingOutput(), "already releasing output");
+        control = null;
+        groupIdsByHash = null;
+    }
+
+    private boolean isReleasingOutput()
+    {
+        return control == null;
+    }
+
     public long hashPosition(int groupId)
     {
-        if (groupId < 0) {
-            throw new IllegalArgumentException("groupId is negative");
+        if (groupId < 0 || groupId >= nextGroupId) {
+            throw new IllegalArgumentException("groupId out of range: " + groupId);
         }
         byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
         int fixedRecordOffset = getFixedRecordOffset(groupId);
+        checkState(!isReleasingOutput() || fixedSizeRecords != null, "groupId already released");
         if (cacheHashValue) {
             return (long) LONG_HANDLE.get(fixedSizeRecords, fixedRecordOffset);
         }
@@ -178,7 +198,8 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
     {
         checkArgument(groupId < nextGroupId, "groupId out of range");
 
-        byte[] fixedSizeRecords = getFixedSizeRecords(groupId);
+        int recordGroupIndex = recordGroupIndexForGroupId(groupId);
+        byte[] fixedSizeRecords = this.fixedSizeRecords[recordGroupIndex];
         int recordOffset = getFixedRecordOffset(groupId);
 
         byte[] variableWidthChunk = null;
@@ -194,6 +215,19 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
                 variableWidthChunk,
                 variableChunkOffset,
                 blockBuilders);
+
+        // Release memory from the previous fixed size records batch
+        if (isReleasingOutput() && recordOffset == 0 && recordGroupIndex > 0) {
+            byte[] releasedRecords = this.fixedSizeRecords[recordGroupIndex - 1];
+            this.fixedSizeRecords[recordGroupIndex - 1] = null;
+            if (releasedRecords == null) {
+                throw new IllegalStateException("already released previous record batch");
+            }
+            fixedRecordGroupsRetainedSize -= sizeOf(releasedRecords);
+            if (variableWidthData != null) {
+                variableWidthData.freeChunksBefore(fixedSizeRecords, recordOffset + variableWidthOffset);
+            }
+        }
     }
 
     public void computeHashes(Block[] blocks, long[] hashes, int offset, int length)
     {
@@ -228,6 +262,7 @@ public int putIfAbsent(Block[] blocks, int position, long hash)
 
     private int getIndex(Block[] blocks, int position, long hash)
     {
+        checkState(!isReleasingOutput(), "already releasing output");
         byte hashPrefix = (byte) (hash & 0x7F | 0x80);
         int bucket = bucket((int) (hash >> 7));
@@ -328,6 +363,7 @@ private void setControl(int index, byte hashPrefix)
 
     public boolean ensureAvailableCapacity(int batchSize)
     {
+        checkState(!isReleasingOutput(), "already releasing output");
         long requiredMaxFill = nextGroupId + batchSize;
         if (requiredMaxFill >= maxFill) {
             long minimumRequiredCapacity = (requiredMaxFill + 1) * 16 / 15;
diff --git a/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java
index 7807c66a4065..fa51578956d5 100644
--- a/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java
@@ -105,6 +105,14 @@ static GroupByHash createGroupByHash(
 
     void appendValuesTo(int groupId, PageBuilder pageBuilder);
 
+    /**
+     * Signals that no more entries will be inserted, and that only calls to {@link GroupByHash#appendValuesTo(int, PageBuilder)}
+     * with sequential groupId values will be observed after this point, allowing the implementation to potentially
+     * release memory associated with structures required for inserts or associated with values that have already been
+     * output.
+     */
+    void startReleasingOutput();
+
     Work<?> addPage(Page page);
 
     /**
diff --git a/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java
index 34aacb072b2d..5aaf9aaccc57 100644
--- a/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java
@@ -50,6 +50,12 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder)
         throw new UnsupportedOperationException("NoChannelGroupByHash does not support appendValuesTo");
     }
 
+    @Override
+    public void startReleasingOutput()
+    {
+        throw new UnsupportedOperationException("NoChannelGroupByHash does not support startReleasingOutput");
+    }
+
     @Override
     public Work<?> addPage(Page page)
     {
diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java
index b9b9b83303f6..8f3d47ef5e90 100644
--- a/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java
+++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/builder/InMemoryHashAggregationBuilder.java
@@ -231,12 +231,18 @@ public WorkProcessor<Page> buildResult()
         for (GroupedAggregator groupedAggregator : groupedAggregators) {
             groupedAggregator.prepareFinal();
         }
-        return buildResult(consecutiveGroupIds(), new PageBuilder(buildTypes()), false);
+        // Only incrementally release memory for final aggregations, since partial aggregations have a fixed
+        // memory limit and can be expected to fully flush and release their output quickly
+        boolean releaseMemoryOnOutput = !partial;
+        if (releaseMemoryOnOutput) {
+            groupByHash.startReleasingOutput();
+        }
+        return buildResult(consecutiveGroupIds(), new PageBuilder(buildTypes()), false, releaseMemoryOnOutput);
     }
 
     public WorkProcessor<Page> buildSpillResult()
     {
-        return buildResult(hashSortedGroupIds(), new PageBuilder(buildSpillTypes()), true);
+        return buildResult(hashSortedGroupIds(), new PageBuilder(buildSpillTypes()), true, false);
     }
 
     public List<Type> buildSpillTypes()
@@ -256,7 +262,7 @@ public int getCapacity()
         return groupByHash.getCapacity();
     }
 
-    private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBuilder, boolean appendRawHash)
+    private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBuilder, boolean appendRawHash, boolean releaseMemoryOnOutput)
     {
         int rawHashIndex = groupByChannels.length + groupedAggregators.size();
         return WorkProcessor.create(() -> {
@@ -283,6 +289,11 @@ private WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBu
                 }
             }
 
+            // Update memory usage after producing each page of output
+            if (releaseMemoryOnOutput) {
+                updateMemory();
+            }
+
             return ProcessState.ofResult(pageBuilder.build());
         });
     }
diff --git a/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java
index 644403957f51..46a7459cd4db 100644
--- a/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java
+++ b/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java
@@ -54,6 +54,12 @@ public int getGroupCount()
         return maxGroupId + 1;
     }
 
+    @Override
+    public void startReleasingOutput()
+    {
+        throw new UnsupportedOperationException("Not yet supported");
+    }
+
     @Override
     public void appendValuesTo(int groupId, PageBuilder pageBuilder)
     {
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
index 0f927dd1cbab..b030091605a4 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
@@ -45,6 +45,7 @@
 import static io.trino.spi.type.BigintType.BIGINT;
 import static io.trino.spi.type.VarcharType.VARCHAR;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 public class TestGroupByHash
 {
@@ -314,6 +315,60 @@ public void testUpdateMemoryBigint()
         assertThat(rehashCount.get()).isEqualTo(2 * BIGINT_EXPECTED_REHASH);
     }
 
+    @Test
+    public void testReleaseMemoryOnOutput()
+    {
+        Type type = VARCHAR;
+        // values expands into multiple FlatGroupByHash fixed record groups
+        Block valuesBlock = createStringSequenceBlock(0, 1_000_000);
+
+        GroupByHash groupByHash = createGroupByHash(ImmutableList.of(type), selectGroupByHashMode(false, ImmutableList.of(type)), 10_000, false, new FlatHashStrategyCompiler(new TypeOperators()), () -> true);
+        assertThat(groupByHash.addPage(new Page(valuesBlock)).process()).isTrue();
+        assertThat(groupByHash.getGroupCount()).isEqualTo(valuesBlock.getPositionCount());
+
+        long memoryUsageAfterInput = groupByHash.getEstimatedSize();
+        groupByHash.startReleasingOutput();
+        // memory usage should have decreased from dropping the hash table
+        long memoryUsageAfterReleasingOutput = groupByHash.getEstimatedSize();
+        // single immediate release of memory for the control and groupId by hash values
+        assertThat(memoryUsageAfterReleasingOutput).isLessThan(memoryUsageAfterInput);
+
+        // no more inputs accepted after switching to releasing output
+        assertThatThrownBy(() -> groupByHash.addPage(new Page(valuesBlock)).process())
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessage("already releasing output");
+        assertThatThrownBy(() -> groupByHash.startReleasingOutput())
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessage("already releasing output");
+
+        PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type));
+        int groupId = 0;
+        // FlatGroupByHash first 1024 records are within the first record group
+        for (; groupId < 1024; groupId++) {
+            groupByHash.appendValuesTo(groupId, pageBuilder);
+            pageBuilder.declarePosition();
+        }
+        pageBuilder.build();
+        // No memory released yet after completing the first group
+        assertThat(groupByHash.getEstimatedSize()).isEqualTo(memoryUsageAfterReleasingOutput);
+
+        groupByHash.appendValuesTo(groupId++, pageBuilder);
+        pageBuilder.declarePosition();
+        // Memory released
+        long memoryUsageAfterFirstRelease = groupByHash.getEstimatedSize();
+        assertThat(memoryUsageAfterFirstRelease).isLessThan(memoryUsageAfterReleasingOutput);
+        assertThatThrownBy(() -> groupByHash.getRawHash(0))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessage("groupId already released");
+
+        for (; groupId < valuesBlock.getPositionCount(); groupId++) {
+            groupByHash.appendValuesTo(groupId, pageBuilder);
+            pageBuilder.declarePosition();
+        }
+        // More memory released
+        assertThat(groupByHash.getEstimatedSize()).isLessThan(memoryUsageAfterFirstRelease);
+    }
+
     @Test
     public void testMemoryReservationYield()
     {
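Reviewer note (not part of the patch): the end-to-end contract introduced here is easiest to see as a single driver loop. The sketch below is illustrative only; drainGroupByHash, inputPage, and the page-emission step are hypothetical stand-ins, and yielding and memory accounting are simplified. It mirrors how InMemoryHashAggregationBuilder consumes the API: insert pages, make the one-way switch via startReleasingOutput(), then visit groupIds strictly in ascending order so completed record groups can be freed.

import io.trino.operator.GroupByHash;
import io.trino.operator.Work;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;

// Hypothetical driver; not code from this change.
static void drainGroupByHash(GroupByHash hash, Page inputPage, PageBuilder pageBuilder)
{
    // 1) Insert phase: the hash table structures are fully resident
    Work<?> work = hash.addPage(inputPage);
    while (!work.process()) {
        // process() returns false when the work yields (e.g. while waiting for memory)
    }

    // 2) One-way switch: immediately drops the control and groupIdsByHash arrays and
    //    makes any further addPage() fail with "already releasing output"
    hash.startReleasingOutput();

    // 3) Output phase: groupIds must be visited in ascending order so each completed
    //    fixed-size record group (and its variable width chunks) can be released
    for (int groupId = 0; groupId < hash.getGroupCount(); groupId++) {
        hash.appendValuesTo(groupId, pageBuilder);
        pageBuilder.declarePosition();
        if (pageBuilder.isFull()) {
            Page output = pageBuilder.build(); // emit output downstream
            pageBuilder.reset();
            // refresh memory accounting here, as buildResult() does via updateMemory()
        }
    }
}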
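Reviewer note: the magic numbers in testReleaseMemoryOnOutput (first 1024 records in one group, release first observed when appending groupId 1024) follow from FlatHash's fixed-size record grouping. A minimal model of that arithmetic, assuming 1024 records per record group as the test's comment implies; the shift value is an assumption, not taken from this diff:

// Assumed geometry mirroring FlatHash's record groups.
static final int RECORDS_PER_GROUP_SHIFT = 10;
static final int RECORDS_PER_GROUP = 1 << RECORDS_PER_GROUP_SHIFT; // 1024

static int recordGroupIndexForGroupId(int groupId)
{
    return groupId >>> RECORDS_PER_GROUP_SHIFT;
}

static int recordIndexWithinGroup(int groupId)
{
    return groupId & (RECORDS_PER_GROUP - 1);
}

// Under this model, appendTo(1024, ...) is the first call where recordGroupIndex is 1
// and the in-group offset is 0, which is exactly the condition under which appendTo()
// releases record group 0, matching the test's expectation that the estimated size
// drops only after groupId 1024 is appended.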
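Reviewer note: freeChunksBefore() relies on a null-sentinel invariant worth spelling out. Chunks are only ever released as a prefix of the list, so the first null slot found while walking backwards proves that every earlier slot was cleared by a prior call, and the walk can stop without scanning to index 0. A self-contained model of that bookkeeping (a hypothetical helper over a plain List, not the production code):

import java.util.List;

// Nulls out every slot before chunkIndex, stopping at the first already-null slot.
// Returns the number of bytes released so the caller can adjust retained-size stats.
static long releaseChunksBefore(List<byte[]> chunks, int chunkIndex)
{
    long released = 0;
    for (int i = chunkIndex - 1; i >= 0; i--) {
        byte[] chunk = chunks.set(i, null);
        if (chunk == null) {
            break; // a prior call already cleared this slot and everything before it
        }
        released += chunk.length;
    }
    return released;
}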