diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index 33ece49e03702..fdfa9cacf0c3b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -575,7 +575,7 @@ private void initializeAggregationBuilderIfNeeded() maxPartialMemory, joinCompiler, true, - useSystemMemory ? ReserveType.SYSTEM : ReserveType.USER); + useSystemMemory); } else { verify(!useSystemMemory, "using system memory in spillable aggregations is not supported"); @@ -667,11 +667,4 @@ private static long calculateDefaultOutputHash(List groupByChannels, int g } return result; } - - public enum ReserveType - { - USER, - SYSTEM, - REVOCABLE - } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index b66343ae40bdd..6f5d59843ecd1 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -19,7 +19,6 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.operator.GroupByHash; -import com.facebook.presto.operator.HashAggregationOperator.ReserveType; import com.facebook.presto.operator.HashCollisionsCounter; import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.operator.TransformWork; @@ -45,7 +44,6 @@ import java.util.List; import java.util.Optional; import java.util.OptionalLong; -import java.util.function.Consumer; import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -63,8 +61,7 @@ public class InMemoryHashAggregationBuilder private final OptionalLong maxPartialMemory; private final LocalMemoryContext systemMemoryContext; private final LocalMemoryContext localUserMemoryContext; - private final ReserveType reserveType; - private final Consumer memoryConsumer; + private final boolean useSystemMemory; private boolean full; @@ -79,7 +76,7 @@ public InMemoryHashAggregationBuilder( Optional maxPartialMemory, JoinCompiler joinCompiler, boolean yieldForMemoryReservation, - ReserveType reserveType) + boolean useSystemMemory) { this(accumulatorFactories, step, @@ -92,36 +89,7 @@ public InMemoryHashAggregationBuilder( Optional.empty(), joinCompiler, yieldForMemoryReservation, - reserveType, - Optional.empty()); - } - - public InMemoryHashAggregationBuilder( - List accumulatorFactories, - Step step, - int expectedGroups, - List groupByTypes, - List groupByChannels, - Optional hashChannel, - OperatorContext operatorContext, - Optional maxPartialMemory, - JoinCompiler joinCompiler, - boolean yieldForMemoryReservation, - Optional> memoryConsumer) - { - this(accumulatorFactories, - step, - expectedGroups, - groupByTypes, - groupByChannels, - hashChannel, - operatorContext, - maxPartialMemory, - Optional.empty(), - joinCompiler, - yieldForMemoryReservation, - ReserveType.REVOCABLE, - memoryConsumer); + useSystemMemory); } public InMemoryHashAggregationBuilder( @@ -136,24 +104,8 @@ public InMemoryHashAggregationBuilder( Optional overwriteIntermediateChannelOffset, JoinCompiler joinCompiler, boolean yieldForMemoryReservation, - ReserveType reserveType, - Optional> memoryConsumer) + boolean useSystemMemory) { - // reserveType is REVOCABLE implies current InMemoryHashAggregationBuilder is built from SpillableHashAggregationBuilder - // and it will accept a customized memoryConsumer for memory update - if (reserveType == ReserveType.REVOCABLE) { - checkArgument(memoryConsumer.isPresent(), - "memoryConsumer must be present when reserve type is REVOCABLE"); - } - - this.reserveType = reserveType; - if (memoryConsumer.isPresent()) { - this.memoryConsumer = memoryConsumer.get(); - } - else { - this.memoryConsumer = this::updateMemory; - } - UpdateMemory updateMemory; if (yieldForMemoryReservation) { updateMemory = this::updateMemoryWithYieldInfo; @@ -161,6 +113,7 @@ public InMemoryHashAggregationBuilder( else { // Report memory usage but do not yield for memory. // This is specially used for spillable hash aggregation operator. + // TODO: revisit this when spillable hash aggregation operator is turned on updateMemory = () -> { updateMemoryWithYieldInfo(); return true; @@ -179,6 +132,7 @@ public InMemoryHashAggregationBuilder( this.maxPartialMemory = maxPartialMemory.map(dataSize -> OptionalLong.of(dataSize.toBytes())).orElseGet(OptionalLong::empty); this.systemMemoryContext = operatorContext.newLocalSystemMemoryContext(InMemoryHashAggregationBuilder.class.getSimpleName()); this.localUserMemoryContext = operatorContext.localUserMemoryContext(); + this.useSystemMemory = useSystemMemory; // wrapper each function with an aggregator ImmutableList.Builder builder = ImmutableList.builder(); @@ -197,7 +151,7 @@ public InMemoryHashAggregationBuilder( @Override public void close() { - memoryConsumer.accept(0L); + updateMemory(0); } @Override @@ -372,28 +326,24 @@ private boolean updateMemoryWithYieldInfo() { long memorySize = getSizeInMemory(); if (partial && maxPartialMemory.isPresent()) { - memoryConsumer.accept(memorySize); + updateMemory(memorySize); full = (memorySize > maxPartialMemory.getAsLong()); return true; } // Operator/driver will be blocked on memory after we call setBytes. // If memory is not available, once we return, this operator will be blocked until memory is available. - memoryConsumer.accept(memorySize); + updateMemory(memorySize); // If memory is not available, inform the caller that we cannot proceed for allocation. return operatorContext.isWaitingForMemory().isDone(); } private void updateMemory(long memorySize) { - switch (reserveType) { - case USER: - localUserMemoryContext.setBytes(memorySize); - break; - case SYSTEM: - systemMemoryContext.setBytes(memorySize); - break; - default: - throw new AssertionError("InMemoryHashAggregationBuilder do not support reserve type: " + reserveType); + if (useSystemMemory) { + systemMemoryContext.setBytes(memorySize); + } + else { + localUserMemoryContext.setBytes(memorySize); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java index 302692c0a9525..67793140d53c5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java @@ -16,7 +16,6 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; -import com.facebook.presto.operator.HashAggregationOperator.ReserveType; import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.operator.WorkProcessor; import com.facebook.presto.operator.WorkProcessor.Transformation; @@ -151,7 +150,6 @@ private void rebuildHashAggregationBuilder() Optional.of(overwriteIntermediateChannelOffset), joinCompiler, false, - ReserveType.USER, - Optional.empty()); + false); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index eb7c2bc456986..2f36153c25e57 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -335,7 +335,7 @@ private void rebuildHashAggregationBuilder() Optional.of(DataSize.succinctBytes(0)), joinCompiler, false, - Optional.of((memorySize) -> localRevocableMemoryContext.setBytes(memorySize))); + false); emptyHashAggregationBuilderSize = hashAggregationBuilder.getSizeInMemory(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java index edbfc8486d93f..41830d6ce937b 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingTaskContext.java @@ -55,15 +55,6 @@ public static TaskContext createTaskContext(Executor notificationExecutor, Sched .build(); } - public static TaskContext createTaskContext(Executor notificationExecutor, ScheduledExecutorService yieldExecutor, Session session, - DataSize maxMemory, DataSize maxTotalMemory) - { - return builder(notificationExecutor, yieldExecutor, session) - .setQueryMaxMemory(maxMemory) - .setQueryMaxTotalMemory(maxTotalMemory) - .build(); - } - public static TaskContext createTaskContext(Executor notificationExecutor, ScheduledExecutorService yieldExecutor, Session session, TaskStateMachine taskStateMachine) { return builder(notificationExecutor, yieldExecutor, session) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java index 8ffe36a4c3d11..9b17a4377836a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java @@ -620,54 +620,6 @@ public void testMergeWithMemorySpill() assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, resultBuilder.build()); } - @Test - public void testMemoryLimitInSpillWhenTriggerRehash() - { - RowPagesBuilder rowPagesBuilder = rowPagesBuilder(BIGINT); - - int smallPagesSpillThresholdSize = 100000; - - List input = rowPagesBuilder - .addSequencePage(smallPagesSpillThresholdSize, 0) - .addSequencePage(smallPagesSpillThresholdSize, smallPagesSpillThresholdSize) - .addSequencePage(smallPagesSpillThresholdSize, 2 * smallPagesSpillThresholdSize) - .addSequencePage(smallPagesSpillThresholdSize, 3 * smallPagesSpillThresholdSize) - .build(); - - HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( - 0, - new PlanNodeId("test"), - ImmutableList.of(BIGINT), - ImmutableList.of(0), - ImmutableList.of(), - ImmutableList.of(), - Step.SINGLE, - false, - ImmutableList.of(generateAccumulatorFactory(LONG_SUM, ImmutableList.of(0), Optional.empty())), - rowPagesBuilder.getHashChannel(), - Optional.empty(), - 1, - Optional.of(new DataSize(16, MEGABYTE)), - true, - new DataSize(smallPagesSpillThresholdSize, Unit.BYTE), - succinctBytes(Integer.MAX_VALUE), - spillerFactory, - joinCompiler, - false); - - TaskContext taskContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, - new DataSize(10, MEGABYTE), new DataSize(20, MEGABYTE)); - DriverContext driverContext = taskContext - .addPipelineContext(0, true, true, false) - .addDriverContext(); - - MaterializedResult.Builder resultBuilder = resultBuilder(driverContext.getSession(), BIGINT, BIGINT); - for (int i = 0; i < 4 * smallPagesSpillThresholdSize; ++i) { - resultBuilder.row((long) i, (long) i); - } - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, resultBuilder.build()); - } - @Test public void testSpillerFailure() {