diff --git a/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java b/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java index 48e104e2daad..cc90bea7d20c 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/QueryContext.java @@ -459,20 +459,20 @@ public QueryMemoryReservationHandler( @Override public ListenableFuture reserveMemory(String allocationTag, long delta, boolean enforceBroadcastMemoryLimit) { - ListenableFuture future = reserveMemoryFunction.apply(allocationTag, delta); if (enforceBroadcastMemoryLimit) { updateBroadcastMemoryFunction.accept(delta); } + ListenableFuture future = reserveMemoryFunction.apply(allocationTag, delta); return future; } @Override public boolean tryReserveMemory(String allocationTag, long delta, boolean enforceBroadcastMemoryLimit) { - if (!tryReserveMemoryFunction.test(allocationTag, delta)) { + if (enforceBroadcastMemoryLimit && !tryUpdateBroadcastMemoryFunction.test(delta)) { return false; } - return !enforceBroadcastMemoryLimit || tryUpdateBroadcastMemoryFunction.test(delta); + return tryReserveMemoryFunction.test(allocationTag, delta); } } diff --git a/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java b/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java index f373b033143a..f576862fc61f 100644 --- a/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java +++ b/presto-main/src/test/java/com/facebook/presto/memory/TestQueryContext.java @@ -138,6 +138,39 @@ public void testChecksTotalMemoryOnUserMemoryAllocation() } } + @Test + public void testChecksTotalMemoryOnUserMemoryAllocationWithBroadcastEnable() + { + MemoryPool generalPool = new MemoryPool(GENERAL_POOL, new DataSize(10, BYTE)); + try (LocalQueryRunner localQueryRunner = new LocalQueryRunner(TEST_SESSION)) { + QueryContext queryContext = new QueryContext( + new QueryId("query"), + new DataSize(10, BYTE), // user memory limit + new DataSize(20, BYTE), // total memory limit + new DataSize(10, BYTE), + new DataSize(1, GIGABYTE), + generalPool, + new TestingGcMonitor(), + localQueryRunner.getExecutor(), + localQueryRunner.getScheduler(), + new DataSize(0, BYTE), + new SpillSpaceTracker(new DataSize(0, BYTE)), + listJsonCodec(TaskMemoryReservationSummary.class)); + + queryContext.getQueryMemoryContext().initializeLocalMemoryContexts("test"); + LocalMemoryContext systemMemoryContext = queryContext.getQueryMemoryContext().localSystemMemoryContext(); + LocalMemoryContext userMemoryContext = queryContext.getQueryMemoryContext().localUserMemoryContext(); + try { + systemMemoryContext.setBytes(15, true); + userMemoryContext.setBytes(6); + } + catch (ExceededMemoryLimitException e) { + assertTrue(e.getMessage().contains("Query exceeded per-node broadcast memory limit of 10B")); + assertEquals(generalPool.getReservedBytes(), 0); + } + } + } + @Test public void testMoveTaggedAllocations() {