Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static io.airlift.concurrent.MoreFutures.getFutureValue;
import static io.prestosql.operator.Operator.NOT_BLOCKED;
import static java.lang.Math.max;

public class SpillableHashAggregationBuilder
Expand Down Expand Up @@ -68,6 +70,7 @@ public class SpillableHashAggregationBuilder

private long hashCollisions;
private double expectedHashCollisions;
private boolean producingOutput;

public SpillableHashAggregationBuilder(
List<AccumulatorFactory> accumulatorFactories,
Expand Down Expand Up @@ -112,8 +115,15 @@ public Work<?> processPage(Page page)
public void updateMemory()
{
checkState(spillInProgress.isDone());
localUserMemoryContext.setBytes(emptyHashAggregationBuilderSize);
localRevocableMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory() - emptyHashAggregationBuilderSize);

if (producingOutput) {
localRevocableMemoryContext.setBytes(0);
localUserMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory());
}
else {
localUserMemoryContext.setBytes(emptyHashAggregationBuilderSize);
localRevocableMemoryContext.setBytes(hashAggregationBuilder.getSizeInMemory() - emptyHashAggregationBuilderSize);
}
}

public long getSizeInMemory()
Expand Down Expand Up @@ -154,9 +164,13 @@ private boolean hasPreviousSpillCompletedSuccessfully()
@Override
public ListenableFuture<?> startMemoryRevoke()
{
checkState(spillInProgress.isDone());
spillToDisk();
return spillInProgress;
if (producingOutput) {
// all revocable memory has been released in buildResult method
verify(localRevocableMemoryContext.getBytes() == 0);
return NOT_BLOCKED;
}

return spillToDisk();
}

@Override
Expand All @@ -174,6 +188,22 @@ private boolean shouldMergeWithMemory(long memorySize)
public WorkProcessor<Page> buildResult()
{
checkState(hasPreviousSpillCompletedSuccessfully(), "Previous spill hasn't yet finished");
producingOutput = true;

// Convert revocable memory to user memory as returned WorkProcessor holds on to memory so we no longer can revoke.
if (localRevocableMemoryContext.getBytes() > 0) {
long currentRevocableBytes = localRevocableMemoryContext.getBytes();
localRevocableMemoryContext.setBytes(0);
if (!localUserMemoryContext.trySetBytes(localUserMemoryContext.getBytes() + currentRevocableBytes)) {
// TODO: this might fail (even though we have just released memory), but we don't
// have a proper way to atomically convert memory reservations
localRevocableMemoryContext.setBytes(currentRevocableBytes);
// spill since revocable memory could not be converted to user memory immediately
// TODO: this should be asynchronous
getFutureValue(spillToDisk());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can take time.
Can we spill asynchronously?

spillToDisk();
return WorkProcessor.empty();

(then we need to updateMemory(); on next call..)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added todo since this requires bigger operator refactor

updateMemory();
}
}

if (!spiller.isPresent()) {
return hashAggregationBuilder.buildResult();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,20 @@ public static List<Page> toPages(Operator operator, Iterator<Page> input)
.build();
}

public static List<Page> toPages(Operator operator, Iterator<Page> input, boolean revokeMemoryWhenAddingPages)
{
return ImmutableList.<Page>builder()
.addAll(toPagesPartial(operator, input, revokeMemoryWhenAddingPages))
.addAll(finishOperator(operator))
.build();
}

public static List<Page> toPagesPartial(Operator operator, Iterator<Page> input)
{
return toPagesPartial(operator, input, true);
}

public static List<Page> toPagesPartial(Operator operator, Iterator<Page> input, boolean revokeMemory)
{
// verify initial state
assertEquals(operator.isFinished(), false);
Expand All @@ -77,7 +90,10 @@ public static List<Page> toPagesPartial(Operator operator, Iterator<Page> input)
if (handledBlocked(operator)) {
continue;
}
handleMemoryRevoking(operator);

if (revokeMemory) {
handleMemoryRevoking(operator);
}

if (input.hasNext() && operator.needsInput()) {
operator.addInput(input.next());
Expand All @@ -102,13 +118,16 @@ public static List<Page> finishOperator(Operator operator)
if (handledBlocked(operator)) {
continue;
}
handleMemoryRevoking(operator);

operator.finish();
Page outputPage = operator.getOutput();
if (outputPage != null && outputPage.getPositionCount() != 0) {
outputPages.add(outputPage);
loopsSinceLastPage = 0;
}

// revoke memory when output pages have started being produced
handleMemoryRevoking(operator);
}

assertEquals(operator.isFinished(), true, "Operator did not finish");
Expand Down Expand Up @@ -137,9 +156,14 @@ private static void handleMemoryRevoking(Operator operator)
}

public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input)
{
return toPages(operatorFactory, driverContext, input, true);
}

public static List<Page> toPages(OperatorFactory operatorFactory, DriverContext driverContext, List<Page> input, boolean revokeMemoryWhenAddingPages)
{
try (Operator operator = operatorFactory.createOperator(driverContext)) {
return toPages(operator, input.iterator());
return toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);
}
catch (Exception e) {
throwIfUnchecked(e);
Expand Down Expand Up @@ -218,12 +242,38 @@ public static void assertOperatorEqualsIgnoreOrder(
boolean hashEnabled,
Optional<Integer> hashChannel)
{
List<Page> pages = toPages(operatorFactory, driverContext, input);
assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, hashChannel, true);
}

public static void assertOperatorEqualsIgnoreOrder(
OperatorFactory operatorFactory,
DriverContext driverContext,
List<Page> input,
MaterializedResult expected,
boolean hashEnabled,
Optional<Integer> hashChannel,
boolean revokeMemoryWhenAddingPages)
{
assertPagesEqualIgnoreOrder(
driverContext,
toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages),
expected,
hashEnabled,
hashChannel);
}

public static void assertPagesEqualIgnoreOrder(
DriverContext driverContext,
List<Page> actualPages,
MaterializedResult expected,
boolean hashEnabled,
Optional<Integer> hashChannel)
{
if (hashEnabled && hashChannel.isPresent()) {
// Drop the hashChannel for all pages
pages = dropChannel(pages, ImmutableList.of(hashChannel.get()));
actualPages = dropChannel(actualPages, ImmutableList.of(hashChannel.get()));
}
MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), pages);
MaterializedResult actual = toMaterializedResult(driverContext.getSession(), expected.getTypes(), actualPages);
assertEqualsIgnoreOrder(actual.getMaterializedRows(), expected.getMaterializedRows());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys;
import static io.prestosql.operator.GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash;
import static io.prestosql.operator.OperatorAssertion.assertOperatorEqualsIgnoreOrder;
import static io.prestosql.operator.OperatorAssertion.assertPagesEqualIgnoreOrder;
import static io.prestosql.operator.OperatorAssertion.dropChannel;
import static io.prestosql.operator.OperatorAssertion.toMaterializedResult;
import static io.prestosql.operator.OperatorAssertion.toPages;
Expand Down Expand Up @@ -130,11 +131,15 @@ public static Object[][] hashEnabled()
public static Object[][] hashEnabledAndMemoryLimitForMergeValuesProvider()
{
return new Object[][] {
{true, true, 8, Integer.MAX_VALUE},
{false, false, 0, 0},
{false, true, 0, 0},
{false, true, 8, 0},
{false, true, 8, Integer.MAX_VALUE}};
{true, true, true, 8, Integer.MAX_VALUE},
{true, true, false, 8, Integer.MAX_VALUE},
{false, false, false, 0, 0},
{false, true, true, 0, 0},
{false, true, false, 0, 0},
{false, true, true, 8, 0},
{false, true, false, 8, 0},
{false, true, true, 8, Integer.MAX_VALUE},
{false, true, false, 8, Integer.MAX_VALUE}};
}

@DataProvider
Expand All @@ -152,8 +157,10 @@ public void tearDown()
}

@Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
{
// make operator produce multiple pages during finish phase
int numberOfRows = 40_000;
MetadataManager metadata = MetadataManager.createTestMetadataManager();
InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(
new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR)));
Expand All @@ -164,9 +171,9 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long
List<Integer> hashChannels = Ints.asList(1);
RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, VARCHAR, VARCHAR, VARCHAR, BIGINT, BOOLEAN);
List<Page> input = rowPagesBuilder
.addSequencePage(10, 100, 0, 100, 0, 500)
.addSequencePage(10, 100, 0, 200, 0, 500)
.addSequencePage(10, 100, 0, 300, 0, 500)
.addSequencePage(numberOfRows, 100, 0, 100_000, 0, 500)
.addSequencePage(numberOfRows, 100, 0, 200_000, 0, 500)
.addSequencePage(numberOfRows, 100, 0, 300_000, 0, 500)
.build();

HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
Expand Down Expand Up @@ -196,25 +203,21 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long

DriverContext driverContext = createDriverContext(memoryLimitForMerge);

MaterializedResult expected = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT)
.row("0", 3L, 0L, 0.0, "300", 3L, 3L)
.row("1", 3L, 3L, 1.0, "301", 3L, 3L)
.row("2", 3L, 6L, 2.0, "302", 3L, 3L)
.row("3", 3L, 9L, 3.0, "303", 3L, 3L)
.row("4", 3L, 12L, 4.0, "304", 3L, 3L)
.row("5", 3L, 15L, 5.0, "305", 3L, 3L)
.row("6", 3L, 18L, 6.0, "306", 3L, 3L)
.row("7", 3L, 21L, 7.0, "307", 3L, 3L)
.row("8", 3L, 24L, 8.0, "308", 3L, 3L)
.row("9", 3L, 27L, 9.0, "309", 3L, 3L)
.build();
MaterializedResult.Builder expectedBuilder = resultBuilder(driverContext.getSession(), VARCHAR, BIGINT, BIGINT, DOUBLE, VARCHAR, BIGINT, BIGINT);
for (int i = 0; i < numberOfRows; ++i) {
expectedBuilder.row(Integer.toString(i), 3L, 3L * i, (double) i, Integer.toString(300_000 + i), 3L, 3L);
}
MaterializedResult expected = expectedBuilder.build();

List<Page> pages = toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
assertGreaterThan(pages.size(), 1, "Expected more than one output page");
assertPagesEqualIgnoreOrder(driverContext, pages, expected, hashEnabled, Optional.of(hashChannels.size()));

assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(hashChannels.size()));
assertTrue(spillEnabled == (spillerFactory.getSpillsCount() > 0), format("Spill state mismatch. Expected spill: %s, spill count: %s", spillEnabled, spillerFactory.getSpillsCount()));
}

@Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
{
MetadataManager metadata = MetadataManager.createTestMetadataManager();
InternalAggregationFunction countVarcharColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(
Expand Down Expand Up @@ -261,11 +264,11 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna
.row(null, 49L, 0L, null, null, null, 0L, 0L)
.build();

assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()));
assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected, hashEnabled, Optional.of(groupByChannels.size()), revokeMemoryWhenAddingPages);
}

@Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
{
MetadataManager metadata = MetadataManager.createTestMetadataManager();
InternalAggregationFunction arrayAggColumn = metadata.getFunctionRegistry().getAggregateFunctionImplementation(
Expand Down Expand Up @@ -304,7 +307,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp
false);

Operator operator = operatorFactory.createOperator(driverContext);
toPages(operator, input.iterator());
toPages(operator, input.iterator(), revokeMemoryWhenAddingPages);
assertEquals(operator.getOperatorContext().getOperatorStats().getUserMemoryReservation().toBytes(), 0);
}

Expand Down Expand Up @@ -349,7 +352,7 @@ public void testMemoryLimit(boolean hashEnabled)
}

@Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues")
public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, boolean revokeMemoryWhenAddingPages, long memoryLimitForMerge, long memoryLimitForMergeWithMemory)
{
BlockBuilder builder = VARCHAR.createBlockBuilder(null, 1, MAX_BLOCK_SIZE_IN_BYTES);
VARCHAR.writeSlice(builder, Slices.allocate(200_000)); // this must be larger than MAX_BLOCK_SIZE_IN_BYTES, 64K
Expand Down Expand Up @@ -385,7 +388,7 @@ public void testHashBuilderResize(boolean hashEnabled, boolean spillEnabled, lon
joinCompiler,
false);

toPages(operatorFactory, driverContext, input);
toPages(operatorFactory, driverContext, input, revokeMemoryWhenAddingPages);
}

@Test(dataProvider = "dataType")
Expand Down