diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java index 2d5dee1e441c2..bdb44b3a7eae7 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java @@ -124,7 +124,7 @@ protected List createOperatorFactories() Optional.empty(), Optional.empty(), 10_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), JOIN_COMPILER); return ImmutableList.of(tableScanOperator, tpchQuery1Operator, aggregationOperator); diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java index 67eb792fb6047..aa4bf946b523e 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java @@ -62,7 +62,7 @@ protected List createOperatorFactories() Optional.empty(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), JOIN_COMPILER); return ImmutableList.of(tableScanOperator, aggregationOperator); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index cca583508f5af..264fac10278b5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -63,7 +63,7 @@ public static class HashAggregationOperatorFactory private final Optional groupIdChannel; private final int expectedGroups; - private final DataSize maxPartialMemory; + private final Optional maxPartialMemory; private final boolean spillEnabled; private final DataSize memoryLimitForMerge; private final DataSize memoryLimitForMergeWithMemory; @@ -84,7 +84,7 @@ public HashAggregationOperatorFactory( Optional hashChannel, Optional groupIdChannel, int expectedGroups, - DataSize maxPartialMemory, + Optional maxPartialMemory, JoinCompiler joinCompiler) { this(operatorId, @@ -120,7 +120,7 @@ public HashAggregationOperatorFactory( Optional hashChannel, Optional groupIdChannel, int expectedGroups, - DataSize maxPartialMemory, + Optional maxPartialMemory, boolean spillEnabled, DataSize unspillMemoryLimit, SpillerFactory spillerFactory, @@ -158,7 +158,7 @@ public HashAggregationOperatorFactory( Optional hashChannel, Optional groupIdChannel, int expectedGroups, - DataSize maxPartialMemory, + Optional maxPartialMemory, boolean spillEnabled, DataSize memoryLimitForMerge, DataSize memoryLimitForMergeWithMemory, @@ -250,7 +250,7 @@ public OperatorFactory duplicate() private final Optional hashChannel; private final Optional groupIdChannel; private final int expectedGroups; - private final DataSize maxPartialMemory; + private final Optional maxPartialMemory; private final boolean spillEnabled; private final DataSize memoryLimitForMerge; private final DataSize memoryLimitForMergeWithMemory; @@ -280,7 +280,7 @@ public HashAggregationOperator( Optional hashChannel, Optional groupIdChannel, int expectedGroups, - DataSize maxPartialMemory, + Optional maxPartialMemory, boolean spillEnabled, DataSize memoryLimitForMerge, DataSize memoryLimitForMergeWithMemory, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index cb521940f192f..5f0069636d7ce 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -45,6 +45,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.OptionalLong; import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled; import static com.facebook.presto.operator.GroupByHash.createGroupByHash; @@ -59,7 +60,7 @@ public class InMemoryHashAggregationBuilder private final List aggregators; private final OperatorContext operatorContext; private final boolean partial; - private final long maxPartialMemory; + private final OptionalLong maxPartialMemory; private final LocalMemoryContext systemMemoryContext; private final LocalMemoryContext localUserMemoryContext; @@ -73,7 +74,7 @@ public InMemoryHashAggregationBuilder( List groupByChannels, Optional hashChannel, OperatorContext operatorContext, - DataSize maxPartialMemory, + Optional maxPartialMemory, JoinCompiler joinCompiler, boolean yieldForMemoryReservation) { @@ -98,7 +99,7 @@ public InMemoryHashAggregationBuilder( List groupByChannels, Optional hashChannel, OperatorContext operatorContext, - DataSize maxPartialMemory, + Optional maxPartialMemory, Optional overwriteIntermediateChannelOffset, JoinCompiler joinCompiler, boolean yieldForMemoryReservation) @@ -126,7 +127,7 @@ public InMemoryHashAggregationBuilder( updateMemory); this.operatorContext = operatorContext; this.partial = step.isOutputPartial(); - this.maxPartialMemory = maxPartialMemory.toBytes(); + this.maxPartialMemory = maxPartialMemory.map(dataSize -> OptionalLong.of(dataSize.toBytes())).orElseGet(OptionalLong::empty); this.systemMemoryContext = operatorContext.newLocalSystemMemoryContext(InMemoryHashAggregationBuilder.class.getSimpleName()); this.localUserMemoryContext = operatorContext.localUserMemoryContext(); @@ -326,9 +327,10 @@ public List buildTypes() private boolean updateMemoryWithYieldInfo() { long memorySize = getSizeInMemory(); - if (partial) { + // if partial limit is not set, memory is considered as user memory + if (partial && maxPartialMemory.isPresent()) { systemMemoryContext.setBytes(memorySize); - full = (memorySize > maxPartialMemory); + full = (memorySize > maxPartialMemory.getAsLong()); return true; } // Operator/driver will be blocked on memory after we call setBytes. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java index ebe705845f4f2..2a0f751c6cdd5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java @@ -144,7 +144,7 @@ private void rebuildHashAggregationBuilder() groupByPartialChannels, hashChannel, operatorContext, - DataSize.succinctBytes(0), + Optional.of(DataSize.succinctBytes(0)), Optional.of(overwriteIntermediateChannelOffset), joinCompiler, false); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index 2e24b715bade3..46d1e71a19159 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -299,7 +299,7 @@ private void rebuildHashAggregationBuilder() groupByChannels, hashChannel, operatorContext, - DataSize.succinctBytes(0), + Optional.of(DataSize.succinctBytes(0)), joinCompiler, false); emptyHashAggregationBuilderSize = hashAggregationBuilder.getSizeInMemory(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 2d1b4d81e248a..3d9b5b69ce179 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -2173,7 +2173,14 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl new DataSize(0, BYTE), context, 2, - outputMapping); + outputMapping, + 200, + // This aggregation must behave as INTERMEDIATE. + // Using INTERMEDIATE aggregation directly + // is not possible, as it doesn't accept raw input data. + // Disabling partial pre-aggregation memory limit effectively + // turns PARTIAL aggregation into INTERMEDIATE. + Optional.empty()); }).orElse(new DevNullOperatorFactory(context.getNextOperatorId(), node.getId())); List inputChannels = node.getColumns().stream() @@ -2227,7 +2234,10 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl new DataSize(0, BYTE), context, 0, - outputMapping); + outputMapping, + 200, + // final aggregation ignores partial pre-aggregation memory limit + Optional.empty()); }).orElse(new DevNullOperatorFactory(context.getNextOperatorId(), node.getId())); Map aggregationOutput = outputMapping.build(); @@ -2544,7 +2554,9 @@ private PhysicalOperation planGroupByAggregation( unspillMemoryLimit, context, 0, - mappings); + mappings, + 10_000, + Optional.of(maxPartialAggregationMemorySize)); return new PhysicalOperation(operatorFactory, mappings.build(), context, source); } @@ -2563,7 +2575,9 @@ private OperatorFactory createHashAggregationOperatorFactory( DataSize unspillMemoryLimit, LocalExecutionPlanContext context, int startOutputChannel, - ImmutableMap.Builder outputMappings) + ImmutableMap.Builder outputMappings, + int expectedGroups, + Optional maxPartialAggregationMemorySize) { List aggregationOutputSymbols = new ArrayList<>(); List accumulatorFactories = new ArrayList<>(); @@ -2626,7 +2640,7 @@ private OperatorFactory createHashAggregationOperatorFactory( accumulatorFactories, hashChannel, groupIdChannel, - 10_000, + expectedGroups, maxPartialAggregationMemorySize, spillEnabled, unspillMemoryLimit, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java index 8228819b55431..fb530c19fdeaf 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -165,7 +165,7 @@ private OperatorFactory createHashAggregationOperatorFactory(Optional h hashChannel, Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), false, succinctBytes(8), succinctBytes(Integer.MAX_VALUE), diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java index d03ae1abc2abb..b27dca365ca7b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java @@ -183,7 +183,7 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), @@ -243,7 +243,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna rowPagesBuilder.getHashChannel(), groupIdChannel, 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), @@ -290,7 +290,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), @@ -335,7 +335,7 @@ public void testMemoryLimit(boolean hashEnabled) rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), joinCompiler); toPages(operatorFactory, driverContext, input); @@ -370,7 +370,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, lon rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), spillEnabled, succinctBytes(memoryLimitForMerge), succinctBytes(memoryLimitForMergeWithMemory), @@ -395,7 +395,7 @@ public void testMemoryReservationYield(Type type) Optional.of(1), Optional.empty(), 1, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), joinCompiler); // get result with yield; pick a relatively small buffer for aggregator's memory usage @@ -446,7 +446,7 @@ public void testHashBuilderResizeLimit(boolean hashEnabled) rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), joinCompiler); toPages(operatorFactory, driverContext, input); @@ -479,7 +479,7 @@ public void testMultiSliceAggregationOutput(boolean hashEnabled) rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), joinCompiler); assertEquals(toPages(operatorFactory, createDriverContext(), input).size(), 2); @@ -509,7 +509,7 @@ public void testMultiplePartialFlushes(boolean hashEnabled) rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(1, KILOBYTE), + Optional.of(new DataSize(1, KILOBYTE)), joinCompiler); DriverContext driverContext = createDriverContext(1024); @@ -584,7 +584,7 @@ public void testMergeWithMemorySpill() rowPagesBuilder.getHashChannel(), Optional.empty(), 1, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), true, new DataSize(smallPagesSpillThresholdSize, Unit.BYTE), succinctBytes(Integer.MAX_VALUE), @@ -639,7 +639,7 @@ public void testSpillerFailure() rowPagesBuilder.getHashChannel(), Optional.empty(), 100_000, - new DataSize(16, MEGABYTE), + Optional.of(new DataSize(16, MEGABYTE)), true, succinctBytes(8), succinctBytes(Integer.MAX_VALUE),