diff --git a/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java b/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java index 4d076164e5f47..fbc593eda816f 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java @@ -172,6 +172,8 @@ private synchronized ListenableFuture updateUserMemory(String allocationTag, { if (delta >= 0) { enforceUserMemoryLimit(queryMemoryContext.getUserMemory(), delta, maxUserMemory); + long totalMemory = memoryPool.getQueryMemoryReservation(queryId); + enforceTotalMemoryLimit(totalMemory, delta, maxTotalMemory); return memoryPool.reserve(queryId, allocationTag, delta); } memoryPool.free(queryId, allocationTag, -delta); @@ -257,6 +259,11 @@ private synchronized boolean tryUpdateUserMemory(String allocationTag, long delt if (queryMemoryContext.getUserMemory() + delta > maxUserMemory) { return false; } + + long totalMemory = memoryPool.getQueryMemoryReservation(queryId); + if (totalMemory + delta > maxTotalMemory) { + return false; + } return memoryPool.tryReserve(queryId, allocationTag, delta); } diff --git a/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java b/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java index 86c2f4e251035..4ae5a56b217b2 100644 --- a/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java +++ b/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.memory; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskStateMachine; import com.facebook.presto.memory.context.LocalMemoryContext; @@ -106,6 +107,31 @@ public void testSetMemoryPool(boolean useReservedPool) } } + @Test(expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = ".*Query exceeded per-node total memory limit of 20B.*") + public void testChecksTotalMemoryOnUserMemoryAllocation() + { + try (LocalQueryRunner localQueryRunner = new LocalQueryRunner(TEST_SESSION)) { + QueryContext queryContext = new QueryContext( + new QueryId("query"), + new DataSize(10, BYTE), // user memory limit + new DataSize(20, BYTE), // total memory limit + new DataSize(10, BYTE), + new DataSize(1, GIGABYTE), + new MemoryPool(GENERAL_POOL, new DataSize(10, BYTE)), + new TestingGcMonitor(), + localQueryRunner.getExecutor(), + localQueryRunner.getScheduler(), + new DataSize(0, BYTE), + new SpillSpaceTracker(new DataSize(0, BYTE))); + + queryContext.getQueryMemoryContext().initializeLocalMemoryContexts("test"); + LocalMemoryContext systemMemoryContext = queryContext.getQueryMemoryContext().localSystemMemoryContext(); + LocalMemoryContext userMemoryContext = queryContext.getQueryMemoryContext().localUserMemoryContext(); + systemMemoryContext.setBytes(15); + userMemoryContext.setBytes(6); + } + } + @Test public void testMoveTaggedAllocations() { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java b/presto-main/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java index bd74853895154..a13e253d37083 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java @@ -77,10 +77,11 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< List result = new LinkedList<>(); // mock an adjustable memory pool - QueryId queryId = new QueryId("test_query"); + QueryId queryId1 = new QueryId("test_query1"); + QueryId queryId2 = new QueryId("test_query2"); MemoryPool memoryPool = new MemoryPool(new MemoryPoolId("test"), new DataSize(1, GIGABYTE)); QueryContext queryContext = new QueryContext( - queryId, + queryId2, new DataSize(512, MEGABYTE), new DataSize(1024, MEGABYTE), new DataSize(512, MEGABYTE), @@ -106,7 +107,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< // saturate the pool with a tiny memory left long reservedMemoryInBytes = memoryPool.getFreeBytes() - additionalMemoryInBytes; - memoryPool.reserve(queryId, "test", reservedMemoryInBytes); + memoryPool.reserve(queryId1, "test", reservedMemoryInBytes); long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage(); int oldCapacity = getHashCapacity.apply(operator); @@ -126,7 +127,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< // between rehash and memory used by aggregator if (newMemoryUsage < new DataSize(4, MEGABYTE).toBytes()) { // free the pool for the next iteration - memoryPool.free(queryId, "test", reservedMemoryInBytes); + memoryPool.free(queryId1, "test", reservedMemoryInBytes); // this required in case input is blocked operator.getOutput(); continue; @@ -147,7 +148,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< assertLessThan(actualIncreasedMemory, additionalMemoryInBytes); // free the pool for the next iteration - memoryPool.free(queryId, "test", reservedMemoryInBytes); + memoryPool.free(queryId1, "test", reservedMemoryInBytes); } else { // We failed to finish the page processing i.e. we yielded @@ -174,7 +175,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< assertNull(operator.getOutput()); // Free the pool to unblock - memoryPool.free(queryId, "test", reservedMemoryInBytes); + memoryPool.free(queryId1, "test", reservedMemoryInBytes); // Trigger a process through getOutput() or needsInput() output = operator.getOutput();