diff --git a/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java b/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java
index 8f147867c2d4..7269055b5012 100644
--- a/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java
+++ b/core/trino-main/src/main/java/io/trino/operator/PagesIndex.java
@@ -671,7 +671,12 @@ public long getEstimatedMemoryRequiredToCreateLookupSource(
                 getSingleBigintJoinChannel(joinChannels, types),
                 hashArraySizeSupplier);
         // PageIndex is retained during LookupSource creation, hence any extra memory retained by the PagesIndex must be accounted here
-        long pagesIndexAdditionalRetainedSizeInBytes = INSTANCE_SIZE + sizeOf(positionCounts.elements());
+        long pagesIndexAdditionalRetainedSizeInBytes = getExtraPagesIndexMemoryWithLookupSourceBuild();
         return pagesIndexAdditionalRetainedSizeInBytes + lookupSourceEstimatedRetainedSizeInBytes;
     }
+
+    public long getExtraPagesIndexMemoryWithLookupSourceBuild()
+    {
+        return INSTANCE_SIZE + sizeOf(positionCounts.elements());
+    }
 }
diff --git a/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java
index 3ffeb7c29246..52bd2d6bd548 100644
--- a/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/join/HashBuilderOperator.java
@@ -22,6 +22,7 @@
 import io.airlift.concurrent.MoreFutures;
 import io.airlift.log.Logger;
 import io.airlift.units.DataSize;
+import io.trino.memory.context.CoarseGrainLocalMemoryContext;
 import io.trino.memory.context.LocalMemoryContext;
 import io.trino.operator.DriverContext;
 import io.trino.operator.HashArraySizeSupplier;
@@ -57,6 +58,7 @@
 import static io.airlift.concurrent.MoreFutures.checkSuccess;
 import static io.airlift.concurrent.MoreFutures.getDone;
 import static io.airlift.units.DataSize.succinctBytes;
+import static io.trino.memory.context.CoarseGrainLocalMemoryContext.DEFAULT_GRANULARITY;
 import static java.lang.String.format;
 import static java.util.Objects.requireNonNull;
@@ -149,7 +151,8 @@ public HashBuilderOperator createOperator(DriverContext driverContext)
                     pagesIndexFactory,
                     spillEnabled,
                     singleStreamSpillerFactory,
-                    hashArraySizeSupplier);
+                    hashArraySizeSupplier,
+                    DEFAULT_GRANULARITY);
         }
 
         @Override
@@ -235,6 +238,7 @@ public enum State
     private Optional<SingleStreamSpiller> spiller = Optional.empty();
     private ListenableFuture<DataSize> spillInProgress = immediateFuture(DataSize.ofBytes(0));
     private Optional<ListenableFuture<List<Page>>> unspillInProgress = Optional.empty();
+    private boolean unspilledPagesAdded;
     @Nullable
     private LookupSourceSupplier lookupSourceSupplier;
     private OptionalLong lookupSourceChecksum = OptionalLong.empty();
@@ -255,7 +259,8 @@ public HashBuilderOperator(
             PagesIndex.Factory pagesIndexFactory,
             boolean spillEnabled,
             SingleStreamSpillerFactory singleStreamSpillerFactory,
-            HashArraySizeSupplier hashArraySizeSupplier)
+            HashArraySizeSupplier hashArraySizeSupplier,
+            long memorySyncGranularity)
     {
         requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
 
@@ -264,8 +269,8 @@ public HashBuilderOperator(
         this.filterFunctionFactory = filterFunctionFactory;
         this.sortChannel = sortChannel;
         this.searchFunctionFactories = searchFunctionFactories;
-        this.localUserMemoryContext = operatorContext.localUserMemoryContext();
-        this.localRevocableMemoryContext = operatorContext.localRevocableMemoryContext();
+        this.localUserMemoryContext = new CoarseGrainLocalMemoryContext(operatorContext.localUserMemoryContext(), memorySyncGranularity);
+        this.localRevocableMemoryContext = new CoarseGrainLocalMemoryContext(operatorContext.localRevocableMemoryContext(), memorySyncGranularity);
 
         this.index = pagesIndexFactory.newPagesIndex(lookupSourceFactory.getTypes(), expectedPositions);
         this.lookupSourceFactory = lookupSourceFactory;
@@ -373,6 +378,7 @@ public ListenableFuture<Void> startMemoryRevoke()
             long indexSizeAfterCompaction = index.getEstimatedSize().toBytes();
             if (indexSizeAfterCompaction < indexSizeBeforeCompaction * INDEX_COMPACTION_ON_REVOCATION_TARGET) {
                 finishMemoryRevoke = Optional.of(() -> {});
+                localRevocableMemoryContext.setBytes(indexSizeAfterCompaction);
                 return immediateVoidFuture();
             }
 
@@ -495,12 +501,30 @@ private void finishInput()
             return;
         }
 
+        long memoryRequired = index.getEstimatedMemoryRequiredToCreateLookupSource(
+                hashArraySizeSupplier,
+                sortChannel,
+                hashChannels);
+
+        ListenableFuture<Void> reserved;
+        if (spillEnabled) {
+            reserved = localRevocableMemoryContext.setBytes(memoryRequired);
+        }
+        else {
+            reserved = localUserMemoryContext.setBytes(memoryRequired);
+        }
+
+        if (!reserved.isDone()) {
+            // wait for memory
+            return;
+        }
+
         LookupSourceSupplier partition = buildLookupSource();
         if (spillEnabled) {
-            localRevocableMemoryContext.setBytes(partition.get().getInMemorySizeInBytes());
+            localRevocableMemoryContext.setBytes(partition.get().getInMemorySizeInBytes() + index.getExtraPagesIndexMemoryWithLookupSourceBuild());
         }
         else {
-            localUserMemoryContext.setBytes(partition.get().getInMemorySizeInBytes());
+            localUserMemoryContext.setBytes(partition.get().getInMemorySizeInBytes() + index.getExtraPagesIndexMemoryWithLookupSourceBuild());
         }
         lookupSourceNotNeeded = Optional.of(lookupSourceFactory.lendPartitionLookupSource(partitionIndex, partition));
 
@@ -545,7 +569,10 @@ private void unspillLookupSourceIfRequested()
         verify(unspillInProgress.isEmpty());
 
         long spilledPagesInMemorySize = getSpiller().getSpilledPagesInMemorySize();
-        localUserMemoryContext.setBytes(spilledPagesInMemorySize + index.getEstimatedSize().toBytes());
+        if (!localUserMemoryContext.setBytes(spilledPagesInMemorySize + index.getEstimatedSize().toBytes()).isDone()) {
+            // wait for memory
+            return;
+        }
         long unspillStartNanos = System.nanoTime();
         unspillInProgress = Optional.of(getSpiller().getAllSpilledPages());
         addSuccessCallback(unspillInProgress.get(), ignored -> {
@@ -554,45 +581,59 @@ private void unspillLookupSourceIfRequested()
         });
 
         state = State.INPUT_UNSPILLING;
+        unspilledPagesAdded = false;
     }
 
     private void finishLookupSourceUnspilling()
     {
         checkState(state == State.INPUT_UNSPILLING);
-        if (!unspillInProgress.get().isDone()) {
-            // Pages have not been unspilled yet.
-            return;
-        }
-
-        // Use Queue so that Pages already consumed by Index are not retained by us.
-        Queue<Page> pages = new ArrayDeque<>(getDone(unspillInProgress.get()));
-        unspillInProgress = Optional.empty();
-        long sizeOfUnspilledPages = pages.stream()
-                .mapToLong(Page::getSizeInBytes)
-                .sum();
-        long retainedSizeOfUnspilledPages = pages.stream()
-                .mapToLong(Page::getRetainedSizeInBytes)
-                .sum();
-        log.debug(
-                "Unspilling for operator %s, unspilled partition %d, sizeOfUnspilledPages %s, retainedSizeOfUnspilledPages %s",
-                operatorContext,
-                partitionIndex,
-                succinctBytes(sizeOfUnspilledPages),
-                succinctBytes(retainedSizeOfUnspilledPages));
-        localUserMemoryContext.setBytes(retainedSizeOfUnspilledPages + index.getEstimatedSize().toBytes());
-
-        while (!pages.isEmpty()) {
-            Page next = pages.remove();
-            index.addPage(next);
-            // There is no attempt to compact index, since unspilled pages are unlikely to have blocks with retained size > logical size.
-            retainedSizeOfUnspilledPages -= next.getRetainedSizeInBytes();
+        if (!unspilledPagesAdded) {
+            if (!unspillInProgress.get().isDone()) {
+                // Pages have not been unspilled yet.
+                return;
+            }
+
+            Queue<Page> pages = new ArrayDeque<>(getDone(unspillInProgress.get()));
+            unspillInProgress = Optional.empty();
+            long sizeOfUnspilledPages = pages.stream()
+                    .mapToLong(Page::getSizeInBytes)
+                    .sum();
+            long retainedSizeOfUnspilledPages = pages.stream()
+                    .mapToLong(Page::getRetainedSizeInBytes)
+                    .sum();
+            log.debug(
+                    "Unspilling for operator %s, unspilled partition %d, sizeOfUnspilledPages %s, retainedSizeOfUnspilledPages %s",
+                    operatorContext,
+                    partitionIndex,
+                    succinctBytes(sizeOfUnspilledPages),
+                    succinctBytes(retainedSizeOfUnspilledPages));
             localUserMemoryContext.setBytes(retainedSizeOfUnspilledPages + index.getEstimatedSize().toBytes());
+
+            while (!pages.isEmpty()) {
+                Page next = pages.remove();
+                index.addPage(next);
+                // There is no attempt to compact index, since unspilled pages are unlikely to have blocks with retained size > logical size.
+                retainedSizeOfUnspilledPages -= next.getRetainedSizeInBytes();
+                localUserMemoryContext.setBytes(retainedSizeOfUnspilledPages + index.getEstimatedSize().toBytes());
+            }
+
+            unspilledPagesAdded = true;
+        }
+
+        ListenableFuture<Void> reserved = localUserMemoryContext.setBytes(index.getEstimatedMemoryRequiredToCreateLookupSource(
+                hashArraySizeSupplier,
+                sortChannel,
+                hashChannels));
+        if (!reserved.isDone()) {
+            // Wait for memory
+            return;
         }
 
         LookupSourceSupplier partition = buildLookupSource();
         lookupSourceChecksum.ifPresent(checksum -> checkState(partition.checksum() == checksum, "Unspilled lookupSource checksum does not match original one"));
-        localUserMemoryContext.setBytes(partition.get().getInMemorySizeInBytes());
+        localUserMemoryContext.setBytes(partition.get().getInMemorySizeInBytes() + index.getExtraPagesIndexMemoryWithLookupSourceBuild());
 
         spilledLookupSourceHandle.setLookupSource(partition);
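
Note: the HashBuilderOperator changes above follow a reserve-then-build protocol. LocalMemoryContext.setBytes() returns a future that completes once the memory pool can grant the reservation; when that future is not yet done, the operator returns without changing state, and the same step runs again after the reservation succeeds. A condensed sketch of the pattern, with names mirroring the patch (illustrative only, not the actual operator code):

// Sketch: reserve the estimated lookup-source memory before building it.
long memoryRequired = index.getEstimatedMemoryRequiredToCreateLookupSource(
        hashArraySizeSupplier, sortChannel, hashChannels);
ListenableFuture<Void> reserved = spillEnabled
        ? localRevocableMemoryContext.setBytes(memoryRequired) // revocable: the index can still be spilled
        : localUserMemoryContext.setBytes(memoryRequired);
if (!reserved.isDone()) {
    return; // stay in the current state; this method runs again once memory is granted
}
LookupSourceSupplier partition = buildLookupSource();
// After the build, shrink the reservation to the actual retained size
// (plus the PagesIndex memory still referenced by the lookup source).
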
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
index 51cc8ffd225e..01546017f55a 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
@@ -31,12 +31,15 @@
 import io.trino.execution.scheduler.NodeScheduler;
 import io.trino.execution.scheduler.NodeSchedulerConfig;
 import io.trino.execution.scheduler.UniformNodeSelectorFactory;
+import io.trino.memory.context.LocalMemoryContext;
 import io.trino.metadata.InMemoryNodeManager;
 import io.trino.operator.Driver;
 import io.trino.operator.DriverContext;
 import io.trino.operator.Operator;
 import io.trino.operator.OperatorAssertion;
+import io.trino.operator.OperatorContext;
 import io.trino.operator.OperatorFactory;
+import io.trino.operator.PagesIndex;
 import io.trino.operator.TaskContext;
 import io.trino.operator.ValuesOperator.ValuesOperatorFactory;
 import io.trino.operator.index.PageBuffer;
@@ -47,6 +50,7 @@
 import io.trino.plugin.base.metrics.TDigestHistogram;
 import io.trino.spi.Page;
 import io.trino.spi.block.RunLengthEncodedBlock;
+import io.trino.spi.block.VariableWidthBlock;
 import io.trino.spi.metrics.Metrics;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
@@ -72,6 +76,7 @@
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -79,7 +84,10 @@
 import static io.airlift.concurrent.MoreFutures.getFutureValue;
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
 import static io.trino.RowPagesBuilder.rowPagesBuilder;
+import static io.trino.SequencePageBuilder.createSequencePage;
 import static io.trino.SessionTestUtils.TEST_SESSION;
+import static io.trino.memory.context.CoarseGrainLocalMemoryContext.DEFAULT_GRANULARITY;
+import static io.trino.operator.HashArraySizeSupplier.defaultHashArraySizeSupplier;
 import static io.trino.operator.JoinOperatorType.fullOuterJoin;
 import static io.trino.operator.JoinOperatorType.innerJoin;
 import static io.trino.operator.JoinOperatorType.lookupOuterJoin;
@@ -118,6 +126,7 @@ public class TestHashJoinOperator
     private static final SingleStreamSpillerFactory SINGLE_STREAM_SPILLER_FACTORY = new DummySpillerFactory();
     private static final PartitioningSpillerFactory PARTITIONING_SPILLER_FACTORY = new GenericPartitioningSpillerFactory(SINGLE_STREAM_SPILLER_FACTORY);
     private static final TypeOperators TYPE_OPERATORS = new TypeOperators();
+    private static final long SMALL_MEMORY_POOL_BYTES = DataSize.of(1, DataSize.Unit.MEGABYTE).toBytes();
 
     private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
     private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));
@@ -1285,6 +1294,205 @@ private void testMemoryLimit(boolean parallelBuild, boolean buildHashEnabled)
                 .hasMessageMatching("Query exceeded per-node memory limit of.*");
     }
 
+    @Test
+    public void testHashBuilderFinishInputWaitsForMemory()
+    {
+        testHashBuilderFinishInputWaitsForMemory(true);
+        testHashBuilderFinishInputWaitsForMemory(false);
+    }
+
+    private void testHashBuilderFinishInputWaitsForMemory(boolean spillEnabled)
+    {
+        DriverTestContext contexts = createDriverTestContext();
+        OperatorContext operatorContext = contexts.operatorContext();
+        OperatorContext anotherOperatorContext = contexts.anotherOperatorContext();
+        ImmutableList<Type> types = ImmutableList.of(BIGINT, BIGINT);
+        PartitionedLookupSourceFactory lookupSourceFactory = new PartitionedLookupSourceFactory(
+                types,
+                ImmutableList.of(BIGINT),
+                ImmutableList.of(BIGINT),
+                1,
+                false,
+                TYPE_OPERATORS);
+        try (HashBuilderOperator operator = new HashBuilderOperator(
+                operatorContext,
+                lookupSourceFactory,
+                0,
+                ImmutableList.of(0),
+                ImmutableList.of(1),
+                OptionalInt.empty(),
+                Optional.empty(),
+                Optional.empty(),
+                ImmutableList.of(),
+                10_000,
+                new PagesIndex.TestingFactory(false),
+                spillEnabled,
+                SINGLE_STREAM_SPILLER_FACTORY,
+                defaultHashArraySizeSupplier(),
+                1)) {
+            // add enough pages to require memory reservation when finish() is called
+            for (int i = 0; i < 100; i++) {
+                operator.addInput(createSequencePage(types, 1));
+            }
+
+            // occupy the whole memory pool with another operator so finish() has to wait
+            anotherOperatorContext.getOperatorMemoryContext().localUserMemoryContext().setBytes(SMALL_MEMORY_POOL_BYTES);
+            operator.finish();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.CONSUMING_INPUT);
+            assertThat(operator.isFinished()).isFalse();
+            if (spillEnabled) {
+                assertThat(operatorContext.isWaitingForRevocableMemory()).isNotDone();
+            }
+            else {
+                assertThat(operatorContext.isWaitingForMemory()).isNotDone();
+            }
+
+            // free memory and let finish() proceed
+            anotherOperatorContext.getOperatorMemoryContext().localUserMemoryContext().setBytes(0);
+            operator.finish();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.LOOKUP_SOURCE_BUILT);
+            assertThat(operator.isFinished()).isFalse();
+            if (spillEnabled) {
+                assertThat(operatorContext.isWaitingForRevocableMemory()).isDone();
+            }
+            else {
+                assertThat(operatorContext.isWaitingForMemory()).isDone();
+            }
+        }
+        finally {
+            operatorContext.destroy();
+        }
+    }
+
+    @Test
+    public void testHashBuilderUnspillWaitsForMemory()
+            throws Exception
+    {
+        DriverTestContext contexts = createDriverTestContext();
+        OperatorContext operatorContext = contexts.operatorContext();
+        OperatorContext anotherOperatorContext = contexts.anotherOperatorContext();
+        ImmutableList<Type> types = ImmutableList.of(BIGINT);
+        PartitionedLookupSourceFactory lookupSourceFactory = new PartitionedLookupSourceFactory(
+                types,
+                types,
+                ImmutableList.of(BIGINT),
+                2,
+                false,
+                TYPE_OPERATORS);
+
+        try (HashBuilderOperator operator = new HashBuilderOperator(
+                operatorContext,
+                lookupSourceFactory,
+                0,
+                ImmutableList.of(0),
+                ImmutableList.of(0),
+                OptionalInt.empty(),
+                Optional.empty(),
+                Optional.empty(),
+                ImmutableList.of(),
+                10_000,
+                new PagesIndex.TestingFactory(false),
+                true,
+                SINGLE_STREAM_SPILLER_FACTORY,
+                defaultHashArraySizeSupplier(),
+                1)) {
+            for (int i = 0; i < 100; i++) {
+                operator.addInput(createSequencePage(types, 1));
+            }
+
+            // spill the index
+            revokeMemory(operator);
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.SPILLING_INPUT);
+            operator.finish();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.INPUT_SPILLED);
+
+            // request partition to trigger unspilling
+            PartitionedConsumption<Supplier<LookupSource>> consumption = lookupSourceFactory.finishProbeOperator(OptionalInt.of(1)).get();
+            PartitionedConsumption.Partition<Supplier<LookupSource>> partition = consumption.beginConsumption().next();
+            ListenableFuture<Supplier<LookupSource>> lookupSourceFuture = partition.load();
+            operator.finish();
+            assertThat(operatorContext.isWaitingForMemory()).isDone();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.INPUT_UNSPILLING);
+
+            // block memory so unspilling cannot reserve memory
+            anotherOperatorContext.getOperatorMemoryContext().localUserMemoryContext().setBytes(SMALL_MEMORY_POOL_BYTES);
+            operator.finish();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.INPUT_UNSPILLING);
+            assertThat(operatorContext.isWaitingForMemory()).isNotDone();
+            assertThat(lookupSourceFuture).isNotDone();
+
+            // release memory and continue
+            anotherOperatorContext.getOperatorMemoryContext().localUserMemoryContext().setBytes(0);
+            operator.finish();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.INPUT_UNSPILLED_AND_BUILT);
+
+            assertThat(lookupSourceFuture).isDone();
+            assertThat(operatorContext.isWaitingForMemory()).isDone();
+
+            try (LookupSource lookupSource = lookupSourceFuture.get().get()) {
+                assertThat(lookupSource.getJoinPositionCount()).isEqualTo(100);
+            }
+        }
+        finally {
+            operatorContext.destroy();
+        }
+    }
+
+    @Test
+    public void testMemoryRevokeCompactionUpdatesRevocableMemory()
+    {
+        DriverTestContext contexts = createDriverTestContext();
+        OperatorContext operatorContext = contexts.operatorContext();
+        ImmutableList<Type> types = ImmutableList.of(VARCHAR);
+        PartitionedLookupSourceFactory lookupSourceFactory = new PartitionedLookupSourceFactory(
+                types,
+                types,
+                ImmutableList.of(VARCHAR),
+                2,
+                false,
+                TYPE_OPERATORS);
+
+        try (HashBuilderOperator operator = new HashBuilderOperator(
+                operatorContext,
+                lookupSourceFactory,
+                0,
+                ImmutableList.of(0),
+                ImmutableList.of(0),
+                OptionalInt.empty(),
+                Optional.empty(),
+                Optional.empty(),
+                ImmutableList.of(),
+                10_000,
+                new PagesIndex.TestingFactory(false),
+                true,
+                SINGLE_STREAM_SPILLER_FACTORY,
+                defaultHashArraySizeSupplier(),
+                DEFAULT_GRANULARITY)) {
+            // add page to build index
+            operator.addInput(new Page(new VariableWidthBlock(1, Slices.allocate(100000), new int[] {0, 1}, Optional.empty())));
+
+            LocalMemoryContext revocableMemoryContext = operatorContext.localRevocableMemoryContext();
+            long beforeBytes = revocableMemoryContext.getBytes();
+
+            // trigger memory revocation which performs compaction
+            getFutureValue(operator.startMemoryRevoke());
+            operator.finishMemoryRevoke();
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.CONSUMING_INPUT);
+
+            long afterBytes = revocableMemoryContext.getBytes();
+            assertThat(afterBytes).isLessThan(beforeBytes);
+
+            // subsequent revocation should revoke all memory
+            getFutureValue(operator.startMemoryRevoke());
+            operator.finishMemoryRevoke();
+            assertThat(revocableMemoryContext.getBytes()).isEqualTo(0);
+            assertThat(operator.getState()).isEqualTo(HashBuilderOperator.State.SPILLING_INPUT);
+        }
+        finally {
+            operatorContext.destroy();
+        }
+    }
+
     @Test
     public void testInnerJoinWithEmptyLookupSource()
     {
@@ -1725,17 +1933,23 @@ private OperatorFactory probeOuterJoinOperatorFactory(
                 TYPE_OPERATORS);
     }
 
-    private static List<List<Object>> product(List<List<Object>> left, List<List<Object>> right)
+    private DriverTestContext createDriverTestContext()
     {
-        List<List<Object>> result = new ArrayList<>();
-        for (List<Object> l : left) {
-            for (List<Object> r : right) {
-                result.add(concat(l, r));
-            }
-        }
-        return result;
+        TaskContext taskContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION)
+                .setMemoryPoolSize(DataSize.ofBytes(SMALL_MEMORY_POOL_BYTES))
+                .build();
+        DriverContext driverContext = taskContext
+                .addPipelineContext(0, false, false, false)
+                .addDriverContext();
+        OperatorContext operatorContext = driverContext
+                .addOperatorContext(0, new PlanNodeId("0"), HashBuilderOperator.class.getName());
+        OperatorContext anotherOperatorContext = driverContext
+                .addOperatorContext(1, new PlanNodeId("1"), "another operator");
+        return new DriverTestContext(taskContext, driverContext, operatorContext, anotherOperatorContext);
     }
 
+    private record DriverTestContext(TaskContext taskContext, DriverContext driverContext, OperatorContext operatorContext, OperatorContext anotherOperatorContext) {}
+
     private static List<Object> concat(List<Object> initialElements, List<Object> moreElements)
     {
         return ImmutableList.copyOf(Iterables.concat(initialElements, moreElements));
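
Note: CoarseGrainLocalMemoryContext is used by this patch but its implementation is not part of the diff. A minimal sketch of what such a decorator plausibly does is shown below — round reservations up to a fixed granularity and skip delegate updates that would not change the rounded value, so frequent small setBytes() calls do not each hit the (potentially contended) memory pool. The rounding rule and the DEFAULT_GRANULARITY value here are assumptions for illustration, not the actual class:

import com.google.common.util.concurrent.ListenableFuture;
import io.trino.memory.context.LocalMemoryContext;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static java.util.Objects.requireNonNull;

// Hypothetical sketch, not the actual io.trino.memory.context.CoarseGrainLocalMemoryContext.
public class CoarseGrainLocalMemoryContext
        implements LocalMemoryContext
{
    public static final long DEFAULT_GRANULARITY = 1 << 16; // assumed value

    private final LocalMemoryContext delegate;
    private final long granularity;

    public CoarseGrainLocalMemoryContext(LocalMemoryContext delegate, long granularity)
    {
        this.delegate = requireNonNull(delegate, "delegate is null");
        checkArgument(granularity > 0, "granularity must be positive");
        this.granularity = granularity;
    }

    @Override
    public long getBytes()
    {
        return delegate.getBytes();
    }

    @Override
    public ListenableFuture<Void> setBytes(long bytes)
    {
        long rounded = roundUp(bytes);
        if (rounded == delegate.getBytes()) {
            // reservation unchanged at this granularity; avoid touching the memory pool
            return immediateVoidFuture();
        }
        return delegate.setBytes(rounded);
    }

    @Override
    public boolean trySetBytes(long bytes)
    {
        long rounded = roundUp(bytes);
        return rounded == delegate.getBytes() || delegate.trySetBytes(rounded);
    }

    @Override
    public void close()
    {
        delegate.close();
    }

    private long roundUp(long bytes)
    {
        // round up to the next multiple of granularity so small fluctuations coalesce
        return ((bytes + granularity - 1) / granularity) * granularity;
    }
}

Under this assumed behavior, the memorySyncGranularity of 1 passed by the first two tests makes every setBytes() call visible to the pool immediately, which is why those tests can assert blocking deterministically, while the compaction test uses DEFAULT_GRANULARITY to exercise the coarse-grained path.
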