diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/builder/SpillableHashAggregationBuilder.java index 3dc0a583fa1a..0ecc565e38a4 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -36,9 +36,11 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.prestosql.operator.Operator.NOT_BLOCKED; import static java.lang.Math.max; public class SpillableHashAggregationBuilder @@ -68,6 +70,7 @@ public class SpillableHashAggregationBuilder private long hashCollisions; private double expectedHashCollisions; + private boolean producingOutput; public SpillableHashAggregationBuilder( List accumulatorFactories, @@ -112,8 +115,15 @@ public Work processPage(Page page) public void updateMemory() { checkState(spillInProgress.isDone()); - localUserMemoryContext.setBytes(emptyHashAggregationBuilderSize); - localRevocableMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory() - emptyHashAggregationBuilderSize); + + if (producingOutput) { + localRevocableMemoryContext.setBytes(0); + localUserMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory()); + } + else { + localUserMemoryContext.setBytes(emptyHashAggregationBuilderSize); + localRevocableMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory() - emptyHashAggregationBuilderSize); + } } public long getSizeInMemory() @@ -154,9 +164,13 @@ private boolean hasPreviousSpillCompletedSuccessfully() 
@Override public ListenableFuture startMemoryRevoke() { - checkState(spillInProgress.isDone()); - spillToDisk(); - return spillInProgress; + if (producingOutput) { + // all revocable memory has been released in buildResult method + verify(localRevocableMemoryContext.getBytes() == 0); + return NOT_BLOCKED; + } + + return spillToDisk(); } @Override @@ -174,6 +188,22 @@ private boolean shouldMergeWithMemory(long memorySize) public WorkProcessor buildResult() { checkState(hasPreviousSpillCompletedSuccessfully(), "Previous spill hasn't yet finished"); + producingOutput = true; + + // Convert revocable memory to user memory: the returned WorkProcessor holds on to this memory, so it can no longer be revoked. + if (localRevocableMemoryContext.getBytes() > 0) { + long currentRevocableBytes = localRevocableMemoryContext.getBytes(); + localRevocableMemoryContext.setBytes(0); + if (!localUserMemoryContext.trySetBytes(localUserMemoryContext.getBytes() + currentRevocableBytes)) { + // TODO: this might fail (even though we have just released memory), but we don't + // have a proper way to atomically convert memory reservations + localRevocableMemoryContext.setBytes(currentRevocableBytes); + // spill since revocable memory could not be converted to user memory immediately + // TODO: this should be asynchronous + getFutureValue(spillToDisk()); + updateMemory(); + } + } if (!spiller.isPresent()) { return hashAggregationBuilder.buildResult(); diff --git a/presto-main/src/test/java/io/prestosql/operator/OperatorAssertion.java b/presto-main/src/test/java/io/prestosql/operator/OperatorAssertion.java index d9905a94b61f..41119314c5fd 100644 --- a/presto-main/src/test/java/io/prestosql/operator/OperatorAssertion.java +++ b/presto-main/src/test/java/io/prestosql/operator/OperatorAssertion.java @@ -67,7 +67,20 @@ public static List toPages(Operator operator, Iterator input) .build(); } + public static List toPages(Operator operator, Iterator input, boolean revokeMemoryWhenAddingPages) + { + return
ImmutableList.builder() + .addAll(toPagesPartial(operator, input, revokeMemoryWhenAddingPages)) + .addAll(finishOperator(operator)) + .build(); + } + public static List toPagesPartial(Operator operator, Iterator input) + { + return toPagesPartial(operator, input, true); + } + + public static List toPagesPartial(Operator operator, Iterator input, boolean revokeMemory) { // verify initial state assertEquals(operator.isFinished(), false); @@ -77,7 +90,10 @@ public static List toPagesPartial(Operator operator, Iterator input) if (handledBlocked(operator)) { continue; } - handleMemoryRevoking(operator); + + if (revokeMemory) { + handleMemoryRevoking(operator); + } if (input.hasNext() && operator.needsInput()) { operator.addInput(input.next()); @@ -102,13 +118,16 @@ public static List finishOperator(Operator operator) if (handledBlocked(operator)) { continue; } - handleMemoryRevoking(operator); + operator.finish(); Page outputPage = operator.getOutput(); if (outputPage != null && outputPage.getPositionCount() != 0) { outputPages.add(outputPage); loopsSinceLastPage = 0; } + + // revoke memory when output pages have started being produced + handleMemoryRevoking(operator); } assertEquals(operator.isFinished(), true, "Operator did not finish"); @@ -137,9 +156,14 @@ private static void handleMemoryRevoking(Operator operator) } public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext, List input) + { + return toPages(operatorFactory, driverContext, input, true); + } + + public static List toPages(OperatorFactory operatorFactory, DriverContext driverContext, List input, boolean revokeMemoryWhenAddingPages) { try (Operator operator = operatorFactory.createOperator(driverContext)) { - return toPages(operator, input.iterator()); + return toPages(operator, input.iterator(), revokeMemoryWhenAddingPages); } catch (Exception e) { throwIfUnchecked(e); @@ -218,12 +242,38 @@ public static void assertOperatorEqualsIgnoreOrder( boolean hashEnabled, Optional 
hashChannel) { - List pages = toPages(operatorFactory, driverContext, input); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, hashChannel, true); + } + + public static void assertOperatorEqualsIgnoreOrder( + OperatorFactory operatorFactory, + DriverContext driverContext, + List input, + MaterializedResult expected, + boolean hashEnabled, + Optional hashChannel, + boolean revokeMemoryWhenAddingPages) + { + assertPagesEqualIgnoreOrder( + driverContext, + toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages), + expected, + hashEnabled, + hashChannel); + } + + public static void assertPagesEqualIgnoreOrder( + DriverContext driverContext, + List actualPages, + MaterializedResult expected, + boolean hashEnabled, + Optional hashChannel) + { if (hashEnabled && hashChannel.isPresent()) { // Drop the hashChannel for all pages - pages = dropChannel(pages, ImmutableList.of(hashChannel.get())); + actualPages = dropChannel(actualPages, ImmutableList.of(hashChannel.get())); } - MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), pages); + MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), actualPages); assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows()); } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java index ccfa7abdbfb2..e628fb08e077 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java @@ -74,6 +74,7 @@ import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; import static io.prestosql.operator.GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash; import static 
io.prestosql.operator.OperatorAssertion.assertOperatorEqualsIgnoreOrder; +import static io.prestosql.operator.OperatorAssertion.assertPagesEqualIgnoreOrder; import static io.prestosql.operator.OperatorAssertion.dropChannel; import static io.prestosql.operator.OperatorAssertion.toMaterializedResult; import static io.prestosql.operator.OperatorAssertion.toPages; @@ -130,11 +131,15 @@ public static Object[][] hashEnabled() public static Object[][] hashEnabledAndMemoryLimitForMergeValuesProvider() { return new Object[][] { - {true, true, 8, Integer.MAX_VALUE}, - {false, false, 0, 0}, - {false, true, 0, 0}, - {false, true, 8, 0}, - {false, true, 8, Integer.MAX_VALUE}}; + {true, true, true, 8, Integer.MAX_VALUE}, + {true, true, false, 8, Integer.MAX_VALUE}, + {false, false, false, 0, 0}, + {false, true, true, 0, 0}, + {false, true, false, 0, 0}, + {false, true, true, 8, 0}, + {false, true, false, 8, 0}, + {false, true, true, 8, Integer.MAX_VALUE}, + {false, true, false, 8, Integer.MAX_VALUE}}; } @DataProvider @@ -152,8 +157,10 @@ public void tearDown() } @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { + // make operator produce multiple pages during finish phase + int numberOfRows = 40_000; MetadataManager metadata = MetadataManager.createTestMetadataManager(); InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation( new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR))); @@ -164,9 +171,9 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long List hashChannels = Ints.asList(1); RowPagesBuilder 
rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, VARCHAR, VARCHAR, VARCHAR, BIGINT, BOOLEAN); List input = rowPagesBuilder - .addSequencePage(10, 100, 0, 100, 0, 500) - .addSequencePage(10, 100, 0, 200, 0, 500) - .addSequencePage(10, 100, 0, 300, 0, 500) + .addSequencePage(numberOfRows, 100, 0, 100_000, 0, 500) + .addSequencePage(numberOfRows, 100, 0, 200_000, 0, 500) + .addSequencePage(numberOfRows, 100, 0, 300_000, 0, 500) .build(); HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory( @@ -196,25 +203,21 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long DriverContext driverContext = createDriverContext(memoryLimitForMerge); - MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT) - .row("0", 3L, 0L, 0.0, "300", 3L, 3L) - .row("1", 3L, 3L, 1.0, "301", 3L, 3L) - .row("2", 3L, 6L, 2.0, "302", 3L, 3L) - .row("3", 3L, 9L, 3.0, "303", 3L, 3L) - .row("4", 3L, 12L, 4.0, "304", 3L, 3L) - .row("5", 3L, 15L, 5.0, "305", 3L, 3L) - .row("6", 3L, 18L, 6.0, "306", 3L, 3L) - .row("7", 3L, 21L, 7.0, "307", 3L, 3L) - .row("8", 3L, 24L, 8.0, "308", 3L, 3L) - .row("9", 3L, 27L, 9.0, "309", 3L, 3L) - .build(); + MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT); + for (int i = 0; i < numberOfRows; ++i) { + expectedBuilder.row(Integer.toString(i), 3L, 3L * i, (double) i, Integer.toString(300_000 + i), 3L, 3L); + } + MaterializedResult expected = expectedBuilder.build(); + + List pages = toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages); + assertGreaterThan(pages.size(), 1, "Expected more than one output page"); + assertPagesEqualIgnoreOrder(driverContext, pages, expected, hashEnabled, Optional.of(hashChannels.size())); - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, 
Optional.of(hashChannels.size())); assertTrue(spillEnabled == (spillerFactory.getSpillsCount() > 0), format("Spill state mismatch. Expected spill: %s, spill count: %s", spillEnabled, spillerFactory.getSpillsCount())); } @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { MetadataManager metadata = MetadataManager.createTestMetadataManager(); InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation( @@ -261,11 +264,11 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna .row(null, 49L, 0L, null, null, null, 0L, 0L) .build(); - assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size())); + assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()), revokeMemoryWhenAddingPages); } @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { MetadataManager metadata = MetadataManager.createTestMetadataManager(); InternalAggregationFunction arrayAggColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation( @@ -304,7 +307,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp false); Operator operator = 
operatorFactory.createOperator(driverContext); - toPages(operator, input.iterator()); + toPages(operator, input.iterator(), revokeMemoryWhenAddingPages); assertEquals(operator.getOperatorContext().getOperatorStats().getUserMemoryReservation().toBytes(), 0); } @@ -349,7 +352,7 @@ public void testMemoryLimit(boolean hashEnabled) } @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") - public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) + public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { BlockBuilder builder = VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES); VARCHAR.writeSlice(builder, Slices.allocate(200_000)); // this must be larger than MAX_BLOCK_SIZE_IN_BYTES, 64K @@ -385,7 +388,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, lon joinCompiler, false); - toPages(operatorFactory, driverContext, input); + toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages); } @Test(dataProvider = "dataType")