diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java index 7e05dc848b8c1..7e53665355f6b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java @@ -40,6 +40,8 @@ import java.util.Queue; import static com.facebook.airlift.concurrent.MoreFutures.getDone; +import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalUserMemoryLimit; +import static com.facebook.presto.SystemSessionProperties.getQueryMaxMemoryPerNode; import static com.facebook.presto.operator.SpillingUtils.checkSpillSucceeded; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -366,6 +368,14 @@ private void spillInput(Page page) checkState(spillInProgress.isDone(), "Previous spill still in progress"); checkSpillSucceeded(spillInProgress); spillInProgress = getSpiller().spill(page); + + // check that spilled data can still fit into memory limit as otherwise + // it fails later during unspilling when all spilled pages need to be loaded into memory + long maxUserMemoryBytes = getQueryMaxMemoryPerNode(operatorContext.getSession()).toBytes(); + if (getSpiller().getSpilledPagesInMemorySize() > maxUserMemoryBytes) { + String additionalInfo = format("Spilled: %s, Operator: %s", succinctBytes(getSpiller().getSpilledPagesInMemorySize()), HashBuilderOperator.class.getSimpleName()); + throw exceededLocalUserMemoryLimit(succinctBytes(maxUserMemoryBytes), additionalInfo, false, Optional.empty()); + } } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java index 6d6ad18d0c37e..c9ca060952687 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java @@ -75,6 +75,8 @@ import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY_PER_NODE; +import static com.facebook.presto.SystemSessionProperties.getQueryMaxMemoryPerNode; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; @@ -1207,6 +1209,41 @@ public void testBroadcastMemoryLimit(boolean parallelBuild, boolean buildHashEna buildLookupSource(buildSideSetup); } + @Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of.* \\[Spilled:.*") + public void testSpillMemoryLimit() + { + Session session = testSessionBuilder().setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "1000B").build(); + + TaskContext taskContext = TestingTaskContext.createTaskContext(executor, scheduledExecutor, session, getQueryMaxMemoryPerNode(session)); + RowPagesBuilder buildPages = rowPagesBuilder(true, Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT)) + .addSequencePage(1000, 2000, 3000, 4000); + BuildSideSetup buildSideSetup = setupBuildSide(true, + taskContext, + Ints.asList(0), + buildPages, + Optional.empty(), + true, + SINGLE_STREAM_SPILLER_FACTORY); + instantiateBuildDrivers(buildSideSetup, taskContext); + + JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); + PartitionedLookupSourceFactory lookupSourceFactory = lookupSourceFactoryManager.getJoinBridge(Lifespan.taskWide()); + ListenableFuture lookupSourceProvider = lookupSourceFactory.createLookupSourceProvider(); + List buildDrivers = buildSideSetup.getBuildDrivers(); + + while (!lookupSourceProvider.isDone()) { + for (int i = 0; i < buildDrivers.size(); i++) { + revokeMemory(buildSideSetup.getBuildOperators().get(i)); + buildDrivers.get(i).process(); + } + } + getFutureValue(lookupSourceProvider).close(); + + for (Driver buildDriver : buildDrivers) { + runDriverInThread(executor, buildDriver); + } + } + @Test(dataProvider = "hashJoinTestValues") public void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean probeHashEnabled, boolean buildHashEnabled) {