diff --git a/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java b/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java index 0853bae18f95..30264d8d842b 100644 --- a/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java +++ b/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java @@ -46,6 +46,15 @@ public static ExceededMemoryLimitException exceededLocalTotalMemoryLimit(DataSiz format("Query exceeded per-node total memory limit of %s [%s]", maxMemory, additionalFailureInfo)); } + /** + * Creates the failure reported when a task goes over its total memory limit; tagged with the dedicated EXCEEDED_TASK_MEMORY_LIMIT code so it can be told apart from per-node limit failures. + */ + public static ExceededMemoryLimitException exceededTaskMemoryLimit(DataSize maxMemory, String additionalFailureInfo) + { + return new ExceededMemoryLimitException(StandardErrorCode.EXCEEDED_TASK_MEMORY_LIMIT, + format("Query exceeded per-task total memory limit of %s [%s]", maxMemory, additionalFailureInfo)); + } + private ExceededMemoryLimitException(StandardErrorCode errorCode, String message) { super(errorCode, message); diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 096ae7c60de1..f28b421fc5e8 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -131,6 +131,7 @@ public final class SystemSessionProperties public static final String ENABLE_LARGE_DYNAMIC_FILTERS = "enable_large_dynamic_filters"; public static final String QUERY_MAX_MEMORY_PER_NODE = "query_max_memory_per_node"; public static final String QUERY_MAX_TOTAL_MEMORY_PER_NODE = "query_max_total_memory_per_node"; + public static final String QUERY_MAX_TOTAL_MEMORY_PER_TASK = "query_max_total_memory_per_task"; public static final String IGNORE_DOWNSTREAM_PREFERENCES = "ignore_downstream_preferences"; public static final String FILTERING_SEMI_JOIN_TO_INNER = "rewrite_filtering_semi_join_to_inner_join"; public static final String 
OPTIMIZE_DUPLICATE_INSENSITIVE_JOINS = "optimize_duplicate_insensitive_joins"; @@ -594,6 +595,11 @@ public SystemSessionProperties( "Maximum amount of total memory a query can use per node", nodeMemoryConfig.getMaxQueryTotalMemoryPerNode(), true), + dataSizeProperty( + QUERY_MAX_TOTAL_MEMORY_PER_TASK, + "Maximum amount of memory a single task can use", + nodeMemoryConfig.getMaxQueryTotalMemoryPerTask().orElse(null), + true), booleanProperty( IGNORE_DOWNSTREAM_PREFERENCES, "Ignore Parent's PreferredProperties in AddExchange optimizer", @@ -1160,6 +1166,11 @@ public static DataSize getQueryMaxTotalMemoryPerNode(Session session) return session.getSystemProperty(QUERY_MAX_TOTAL_MEMORY_PER_NODE, DataSize.class); } + public static Optional getQueryMaxTotalMemoryPerTask(Session session) + { + return Optional.ofNullable(session.getSystemProperty(QUERY_MAX_TOTAL_MEMORY_PER_TASK, DataSize.class)); + } + public static boolean ignoreDownStreamPreferences(Session session) { return session.getSystemProperty(IGNORE_DOWNSTREAM_PREFERENCES, Boolean.class); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index cfc69152e129..4efb6d0be026 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -75,6 +75,7 @@ import static io.airlift.concurrent.Threads.threadsNamed; import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode; import static io.trino.SystemSessionProperties.getQueryMaxTotalMemoryPerNode; +import static io.trino.SystemSessionProperties.getQueryMaxTotalMemoryPerTask; import static io.trino.SystemSessionProperties.resourceOvercommit; import static io.trino.execution.SqlTask.createSqlTask; import static io.trino.memory.LocalMemoryManager.GENERAL_POOL; @@ -111,6 +112,7 @@ public class SqlTaskManager private final long queryMaxMemoryPerNode; private final long 
queryMaxTotalMemoryPerNode; + private final Optional queryMaxMemoryPerTask; @GuardedBy("this") private long currentMemoryPoolAssignmentVersion; @@ -155,13 +157,14 @@ public SqlTaskManager( this.localMemoryManager = requireNonNull(localMemoryManager, "localMemoryManager is null"); DataSize maxQueryMemoryPerNode = nodeMemoryConfig.getMaxQueryMemoryPerNode(); DataSize maxQueryTotalMemoryPerNode = nodeMemoryConfig.getMaxQueryTotalMemoryPerNode(); + queryMaxMemoryPerTask = nodeMemoryConfig.getMaxQueryTotalMemoryPerTask(); DataSize maxQuerySpillPerNode = nodeSpillConfig.getQueryMaxSpillPerNode(); queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes(); queryMaxTotalMemoryPerNode = maxQueryTotalMemoryPerNode.toBytes(); queryContexts = CacheBuilder.newBuilder().weakValues().build(CacheLoader.from( - queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQueryTotalMemoryPerNode, maxQuerySpillPerNode))); + queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQueryTotalMemoryPerNode, queryMaxMemoryPerTask, maxQuerySpillPerNode))); tasks = CacheBuilder.newBuilder().build(CacheLoader.from( taskId -> createSqlTask( @@ -184,12 +187,14 @@ private QueryContext createQueryContext( GcMonitor gcMonitor, DataSize maxQueryUserMemoryPerNode, DataSize maxQueryTotalMemoryPerNode, + Optional maxQueryMemoryPerTask, DataSize maxQuerySpillPerNode) { return new QueryContext( queryId, maxQueryUserMemoryPerNode, maxQueryTotalMemoryPerNode, + maxQueryMemoryPerTask, localMemoryManager.getGeneralPool(), gcMonitor, taskNotificationExecutor, @@ -406,11 +411,19 @@ private TaskInfo doUpdateTask( if (!queryContext.isMemoryLimitsInitialized()) { long sessionQueryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes(); long sessionQueryTotalMaxMemoryPerNode = getQueryMaxTotalMemoryPerNode(session).toBytes(); + + Optional effectiveQueryMaxMemoryPerTask = 
getQueryMaxTotalMemoryPerTask(session); + if (queryMaxMemoryPerTask.isPresent() && + (effectiveQueryMaxMemoryPerTask.isEmpty() || effectiveQueryMaxMemoryPerTask.get().toBytes() > queryMaxMemoryPerTask.get().toBytes())) { + effectiveQueryMaxMemoryPerTask = queryMaxMemoryPerTask; + } + // Session properties are only allowed to decrease memory limits, not increase them queryContext.initializeMemoryLimits( resourceOvercommit(session), min(sessionQueryMaxMemoryPerNode, queryMaxMemoryPerNode), - min(sessionQueryTotalMaxMemoryPerNode, queryMaxTotalMemoryPerNode)); + min(sessionQueryTotalMaxMemoryPerNode, queryMaxTotalMemoryPerNode), + effectiveQueryMaxMemoryPerTask); } sqlTask.recordHeartbeat(); diff --git a/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java b/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java index 42e2a92353ce..95b28245873c 100644 --- a/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java +++ b/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java @@ -21,6 +21,8 @@ import javax.validation.constraints.NotNull; +import java.util.Optional; + // This is separate from MemoryManagerConfig because it's difficult to test the default value of maxQueryMemoryPerNode @DefunctConfig("deprecated.legacy-system-pool-enabled") public class NodeMemoryConfig @@ -28,11 +30,14 @@ public class NodeMemoryConfig public static final long AVAILABLE_HEAP_MEMORY = Runtime.getRuntime().maxMemory(); public static final String QUERY_MAX_MEMORY_PER_NODE_CONFIG = "query.max-memory-per-node"; public static final String QUERY_MAX_TOTAL_MEMORY_PER_NODE_CONFIG = "query.max-total-memory-per-node"; + public static final String QUERY_MAX_TOTAL_MEMORY_PER_TASK_CONFIG = "query.max-total-memory-per-task"; private boolean isReservedPoolDisabled = true; private DataSize maxQueryMemoryPerNode = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.1)); + private Optional maxQueryTotalMemoryPerTask = Optional.empty(); + // This is a 
per-query limit for the user plus system allocations. private DataSize maxQueryTotalMemoryPerNode = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3)); private DataSize heapHeadroom = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3)); @@ -50,6 +55,20 @@ public NodeMemoryConfig setMaxQueryMemoryPerNode(DataSize maxQueryMemoryPerNode) return this; } + @NotNull + public Optional getMaxQueryTotalMemoryPerTask() + { + return maxQueryTotalMemoryPerTask; + } + + @Config(QUERY_MAX_TOTAL_MEMORY_PER_TASK_CONFIG) + @ConfigDescription("Sets total (user + system) memory limit enforced for a single task; there is no memory limit by default") + public NodeMemoryConfig setMaxQueryTotalMemoryPerTask(DataSize maxQueryTotalMemoryPerTask) + { + this.maxQueryTotalMemoryPerTask = Optional.ofNullable(maxQueryTotalMemoryPerTask); + return this; + } + @Deprecated @LegacyConfig(value = "experimental.reserved-pool-enabled", replacedBy = "experimental.reserved-pool-disabled") public void setReservedPoolEnabled(boolean reservedPoolEnabled) diff --git a/core/trino-main/src/main/java/io/trino/memory/QueryContext.java b/core/trino-main/src/main/java/io/trino/memory/QueryContext.java index 081cede1f0f6..03462432ec12 100644 --- a/core/trino-main/src/main/java/io/trino/memory/QueryContext.java +++ b/core/trino-main/src/main/java/io/trino/memory/QueryContext.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -51,6 +52,7 @@ import static io.trino.ExceededSpillLimitException.exceededPerQueryLocalLimit; import static io.trino.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; import static io.trino.operator.Operator.NOT_BLOCKED; +import static io.trino.operator.TaskContext.createTaskContext; import static java.lang.String.format; import static 
java.util.Map.Entry.comparingByValue; import static java.util.Objects.requireNonNull; @@ -78,6 +80,8 @@ public class QueryContext private long maxUserMemory; @GuardedBy("this") private long maxTotalMemory; + @GuardedBy("this") + private Optional maxTaskMemory; private final MemoryTrackingContext queryMemoryContext; @@ -91,6 +95,7 @@ public QueryContext( QueryId queryId, DataSize maxUserMemory, DataSize maxTotalMemory, + Optional maxTaskMemory, MemoryPool memoryPool, GcMonitor gcMonitor, Executor notificationExecutor, @@ -101,6 +106,7 @@ public QueryContext( this.queryId = requireNonNull(queryId, "queryId is null"); this.maxUserMemory = requireNonNull(maxUserMemory, "maxUserMemory is null").toBytes(); this.maxTotalMemory = requireNonNull(maxTotalMemory, "maxTotalMemory is null").toBytes(); + this.maxTaskMemory = requireNonNull(maxTaskMemory, "maxTaskMemory is null"); this.memoryPool = requireNonNull(memoryPool, "memoryPool is null"); this.gcMonitor = requireNonNull(gcMonitor, "gcMonitor is null"); this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); @@ -119,7 +125,7 @@ public boolean isMemoryLimitsInitialized() } // TODO: This method should be removed, and the correct limit set in the constructor. However, due to the way QueryContext is constructed the memory limit is not known in advance - public synchronized void initializeMemoryLimits(boolean resourceOverCommit, long maxUserMemory, long maxTotalMemory) + public synchronized void initializeMemoryLimits(boolean resourceOverCommit, long maxUserMemory, long maxTotalMemory, Optional maxTaskMemory) { checkArgument(maxUserMemory >= 0, "maxUserMemory must be >= 0, found: %s", maxUserMemory); checkArgument(maxTotalMemory >= 0, "maxTotalMemory must be >= 0, found: %s", maxTotalMemory); @@ -129,10 +135,12 @@ public synchronized void initializeMemoryLimits(boolean resourceOverCommit, long // The coordinator will kill the query if the cluster runs out of memory. 
this.maxUserMemory = memoryPool.getMaxBytes(); this.maxTotalMemory = memoryPool.getMaxBytes(); + this.maxTaskMemory = Optional.empty(); // disabled } else { this.maxUserMemory = maxUserMemory; this.maxTotalMemory = maxTotalMemory; + this.maxTaskMemory = maxTaskMemory; } memoryLimitsInitialized = true; } @@ -294,7 +302,7 @@ public TaskContext addTaskContext( boolean perOperatorCpuTimerEnabled, boolean cpuTimerEnabled) { - TaskContext taskContext = TaskContext.createTaskContext( + TaskContext taskContext = createTaskContext( this, taskStateMachine, gcMonitor, @@ -304,7 +312,8 @@ public TaskContext addTaskContext( queryMemoryContext.newMemoryTrackingContext(), notifyStatusChanged, perOperatorCpuTimerEnabled, - cpuTimerEnabled); + cpuTimerEnabled, + maxTaskMemory); taskContexts.put(taskStateMachine.getTaskId(), taskContext); return taskContext; } diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java b/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java new file mode 100644 index 000000000000..41edc9c6aa55 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.operator; + +import io.airlift.units.DataSize; +import io.trino.memory.context.MemoryAllocationValidator; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.units.DataSize.succinctBytes; +import static io.trino.ExceededMemoryLimitException.exceededTaskMemoryLimit; +import static java.lang.String.format; +import static java.util.Map.Entry.comparingByValue; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +// Keeps track of per-node memory usage of given task. Single instance is shared by multiple ValidatingLocalMemoryContext instances +// originating from single ValidatingAggregateContext. +public class TaskAllocationValidator + implements MemoryAllocationValidator +{ + private final long limitBytes; + @GuardedBy("this") + private long usedBytes; + @GuardedBy("this") + private final Map<String, Long> taggedAllocations = new HashMap<>(); + + public TaskAllocationValidator(DataSize memoryLimit) + { + this.limitBytes = requireNonNull(memoryLimit, "memoryLimit is null").toBytes(); + } + + @Override + public synchronized void reserveMemory(String allocationTag, long delta) + { + if (usedBytes + delta > limitBytes) { + // usedBytes is kept within the limit, so only a positive delta is expected to get here + verify(delta > 0, "exceeded limit with negative delta (%s); usedBytes=%s, limitBytes=%s", delta, usedBytes, limitBytes); + raiseLimitExceededFailure(allocationTag, delta); + } + usedBytes += delta; + taggedAllocations.merge(allocationTag, delta, Long::sum); + } + + // Throws the limit-exceeded failure, reporting current usage, the rejected delta and up to three largest tagged consumers + private synchronized void raiseLimitExceededFailure(String currentAllocationTag, long currentAllocationDelta) + { + Map<String, Long> tmpTaggedAllocations = new HashMap<>(taggedAllocations); + // include current allocation in the output of top-consumers + tmpTaggedAllocations.merge(currentAllocationTag, 
currentAllocationDelta, Long::sum); + String topConsumers = tmpTaggedAllocations.entrySet().stream() + .sorted(comparingByValue(Comparator.reverseOrder())) + .limit(3) + .filter(e -> e.getValue() >= 0) + .collect(toImmutableMap(Map.Entry::getKey, e -> succinctBytes(e.getValue()))) + .toString(); + + String additionalInfo = format("Allocated: %s, Delta: %s, Top Consumers: %s", succinctBytes(usedBytes), succinctBytes(currentAllocationDelta), topConsumers); + throw exceededTaskMemoryLimit(DataSize.succinctBytes(limitBytes), additionalInfo); + } + + @Override + public synchronized boolean tryReserveMemory(String allocationTag, long delta) + { + if (usedBytes + delta > limitBytes) { + verify(delta > 0, "exceeded limit with negative delta (%s); usedBytes=%s, limitBytes=%s", delta, usedBytes, limitBytes); + return false; + } + usedBytes += delta; + // record the allocation under its tag as reserveMemory does, so top-consumers reporting stays consistent with usedBytes + taggedAllocations.merge(allocationTag, delta, Long::sum); + return true; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java index fd08d73d9f96..0701a552fe95 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java @@ -33,7 +33,9 @@ import io.trino.memory.QueryContext; import io.trino.memory.QueryContextVisitor; import io.trino.memory.context.LocalMemoryContext; +import io.trino.memory.context.MemoryAllocationValidator; import io.trino.memory.context.MemoryTrackingContext; +import io.trino.memory.context.ValidatingAggregateContext; import io.trino.spi.predicate.Domain; import io.trino.sql.planner.LocalDynamicFiltersCollector; import io.trino.sql.planner.plan.DynamicFilterId; @@ -44,6 +46,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; @@ -123,7 +126,8 @@ public static TaskContext createTaskContext( MemoryTrackingContext taskMemoryContext, Runnable 
notifyStatusChanged, boolean perOperatorCpuTimerEnabled, - boolean cpuTimerEnabled) + boolean cpuTimerEnabled, + Optional maxMemory) { TaskContext taskContext = new TaskContext( queryContext, @@ -135,7 +139,8 @@ public static TaskContext createTaskContext( taskMemoryContext, notifyStatusChanged, perOperatorCpuTimerEnabled, - cpuTimerEnabled); + cpuTimerEnabled, + maxMemory); taskContext.initialize(); return taskContext; } @@ -150,7 +155,8 @@ private TaskContext( MemoryTrackingContext taskMemoryContext, Runnable notifyStatusChanged, boolean perOperatorCpuTimerEnabled, - boolean cpuTimerEnabled) + boolean cpuTimerEnabled, + Optional maxMemory) { this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null"); this.gcMonitor = requireNonNull(gcMonitor, "gcMonitor is null"); @@ -158,9 +164,21 @@ private TaskContext( this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); this.yieldExecutor = requireNonNull(yieldExecutor, "yieldExecutor is null"); this.session = session; - this.taskMemoryContext = requireNonNull(taskMemoryContext, "taskMemoryContext is null"); + + requireNonNull(taskMemoryContext, "taskMemoryContext is null"); + if (maxMemory.isPresent()) { + MemoryAllocationValidator memoryValidator = new TaskAllocationValidator(maxMemory.get()); + this.taskMemoryContext = new MemoryTrackingContext( + new ValidatingAggregateContext(taskMemoryContext.aggregateUserMemoryContext(), memoryValidator), + taskMemoryContext.aggregateRevocableMemoryContext(), + new ValidatingAggregateContext(taskMemoryContext.aggregateSystemMemoryContext(), memoryValidator)); + } + else { + this.taskMemoryContext = taskMemoryContext; + } + // Initialize the local memory contexts with the LazyOutputBuffer tag as LazyOutputBuffer will do the local memory allocations - taskMemoryContext.initializeLocalMemoryContexts(LazyOutputBuffer.class.getSimpleName()); + 
this.taskMemoryContext.initializeLocalMemoryContexts(LazyOutputBuffer.class.getSimpleName()); this.dynamicFiltersCollector = new DynamicFiltersCollector(notifyStatusChanged); this.localDynamicFiltersCollector = new LocalDynamicFiltersCollector(session); this.perOperatorCpuTimerEnabled = perOperatorCpuTimerEnabled; diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java index 64366c52f1e9..516197f3a3e8 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java @@ -27,6 +27,7 @@ import io.trino.spi.memory.MemoryPoolId; import io.trino.spiller.SpillSpaceTracker; +import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -144,6 +145,7 @@ public TaskContext build() queryId, queryMaxMemory, queryMaxTotalMemory, + Optional.empty(), memoryPool, GC_MONITOR, notificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 8dd5e83d03e5..f586042f4b6b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -196,6 +196,7 @@ public MockRemoteTask( QueryContext queryContext = new QueryContext(taskId.getQueryId(), DataSize.of(1, MEGABYTE), DataSize.of(2, MEGABYTE), + Optional.empty(), memoryPool, new TestingGcMonitor(), executor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java index 14eebbb832ea..c8ab0f6a70ab 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java +++ 
b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java @@ -43,6 +43,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ScheduledExecutorService; @@ -314,6 +315,7 @@ private QueryContext getOrCreateQueryContext(QueryId queryId) return queryContexts.computeIfAbsent(queryId, id -> new QueryContext(id, DataSize.of(1, MEGABYTE), DataSize.of(2, MEGABYTE), + Optional.empty(), memoryPool, new TestingGcMonitor(), executor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index ea962a0437cf..8be142f4bb00 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -345,6 +345,7 @@ private SqlTask createInitialTask() QueryContext queryContext = new QueryContext(new QueryId("query"), DataSize.of(1, MEGABYTE), DataSize.of(2, MEGABYTE), + Optional.empty(), new MemoryPool(new MemoryPoolId("test"), DataSize.of(1, GIGABYTE)), new TestingGcMonitor(), taskNotificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 46ab542263ce..ce6351779c2f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -598,6 +598,7 @@ private TaskContext newTestingTaskContext(ScheduledExecutorService taskNotificat new QueryId("queryid"), DataSize.of(1, MEGABYTE), DataSize.of(2, MEGABYTE), + Optional.empty(), new MemoryPool(new MemoryPoolId("test"), DataSize.of(1, GIGABYTE)), new TestingGcMonitor(), taskNotificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java 
b/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java index b857303477e9..862905f18f64 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -96,6 +97,7 @@ private void setUp(Supplier> driversSupplier) QueryContext queryContext = new QueryContext(new QueryId("query"), TEN_MEGABYTES, DataSize.of(20, MEGABYTE), + Optional.empty(), userPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java index 8307859864ec..1d1de4f74eb5 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java @@ -39,6 +39,7 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -94,12 +95,18 @@ public void tearDown() @BeforeMethod public void setUpTest() + { + setupTestWithLimits(queryMaxMemory, queryMaxTotalMemory, Optional.empty()); + } + + private void setupTestWithLimits(DataSize queryMaxMemory, DataSize queryMaxTotalMemory, Optional queryMaxTaskMemory) { memoryPool = new MemoryPool(new MemoryPoolId("test"), memoryPoolSize); queryContext = new QueryContext( new QueryId("test_query"), queryMaxMemory, queryMaxTotalMemory, + queryMaxTaskMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, @@ -156,6 +163,21 @@ public void testLocalTotalMemoryLimitExceeded() .hasMessage("Query exceeded per-node total memory limit of %1$s [Allocated: %1$s, Delta: 1B, Top 
Consumers: {test=%1$s}]", queryMaxTotalMemory); } + @Test + public void testTaskMemoryLimitExceeded() + { + DataSize taskMaxMemory = DataSize.of(1, GIGABYTE); + setupTestWithLimits(DataSize.of(2, GIGABYTE), DataSize.of(2, GIGABYTE), Optional.of(taskMaxMemory)); + LocalMemoryContext systemMemoryContext = operatorContext.newLocalSystemMemoryContext("test"); + systemMemoryContext.setBytes(100); + assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 100, 0); + systemMemoryContext.setBytes(taskMaxMemory.toBytes()); + assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, taskMaxMemory.toBytes(), 0); + assertThatThrownBy(() -> systemMemoryContext.setBytes(taskMaxMemory.toBytes() + 1)) + .isInstanceOf(ExceededMemoryLimitException.class) + .hasMessage("Query exceeded per-task total memory limit of %1$s [Allocated: %s, Delta: 1B, Top Consumers: {test=%s}]", taskMaxMemory, DataSize.succinctBytes(taskMaxMemory.toBytes() + 1)); + } + @Test public void testLocalSystemAllocations() { diff --git a/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java b/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java index 1ee999c20e73..999fb8b88c34 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java @@ -23,6 +23,7 @@ import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.memory.NodeMemoryConfig.AVAILABLE_HEAP_MEMORY; public class TestNodeMemoryConfig @@ -33,6 +34,7 @@ public void testDefaults() assertRecordedDefaults(recordDefaults(NodeMemoryConfig.class) .setMaxQueryMemoryPerNode(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.1))) 
.setMaxQueryTotalMemoryPerNode(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3))) + .setMaxQueryTotalMemoryPerTask(null) .setHeapHeadroom(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3))) .setReservedPoolDisabled(true)); } @@ -43,6 +45,7 @@ public void testExplicitPropertyMappings() Map properties = new ImmutableMap.Builder() .put("query.max-memory-per-node", "1GB") .put("query.max-total-memory-per-node", "3GB") + .put("query.max-total-memory-per-task", "200MB") .put("memory.heap-headroom-per-node", "1GB") .put("experimental.reserved-pool-disabled", "false") .build(); @@ -50,6 +53,7 @@ public void testExplicitPropertyMappings() NodeMemoryConfig expected = new NodeMemoryConfig() .setMaxQueryMemoryPerNode(DataSize.of(1, GIGABYTE)) .setMaxQueryTotalMemoryPerNode(DataSize.of(3, GIGABYTE)) + .setMaxQueryTotalMemoryPerTask(DataSize.of(200, MEGABYTE)) .setHeapHeadroom(DataSize.of(1, GIGABYTE)) .setReservedPoolDisabled(false); diff --git a/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java b/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java index 66a27fce52cd..a60855b78e0c 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java @@ -31,6 +31,7 @@ import org.testng.annotations.Test; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; import static io.airlift.concurrent.Threads.threadsNamed; @@ -76,6 +77,7 @@ public void testSetMemoryPool(boolean useReservedPool) new QueryId("query"), DataSize.ofBytes(10), DataSize.ofBytes(20), + Optional.empty(), new MemoryPool(GENERAL_POOL, DataSize.ofBytes(10)), new TestingGcMonitor(), localQueryRunner.getExecutor(), @@ -141,6 +143,7 @@ private static QueryContext createQueryContext(QueryId queryId, MemoryPool gener return new QueryContext(queryId, DataSize.ofBytes(10_000), DataSize.ofBytes(10_000), + Optional.empty(), generalPool, new 
TestingGcMonitor(), TEST_EXECUTOR, diff --git a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java index 0d325f32d750..7112aa5d40cf 100644 --- a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java +++ b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java @@ -27,6 +27,7 @@ import java.util.LinkedList; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Function; @@ -83,6 +84,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< queryId, DataSize.of(512, MEGABYTE), DataSize.of(1024, MEGABYTE), + Optional.empty(), memoryPool, new TestingGcMonitor(), EXECUTOR, diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 34222a4dc5bf..1ccf1c1cd286 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -169,6 +169,7 @@ public enum StandardErrorCode EXCEEDED_LOCAL_MEMORY_LIMIT(131079, INSUFFICIENT_RESOURCES), ADMINISTRATIVELY_PREEMPTED(131080, INSUFFICIENT_RESOURCES), EXCEEDED_SCAN_LIMIT(131081, INSUFFICIENT_RESOURCES), + EXCEEDED_TASK_MEMORY_LIMIT(131082, INSUFFICIENT_RESOURCES), /**/; diff --git a/lib/trino-memory-context/pom.xml b/lib/trino-memory-context/pom.xml index 3751d4d19408..cdcc8e5d65e9 100644 --- a/lib/trino-memory-context/pom.xml +++ b/lib/trino-memory-context/pom.xml @@ -18,6 +18,11 @@ + + io.airlift + log + + io.airlift units @@ -41,6 +46,12 @@ test + + org.assertj + assertj-core + test + + org.testng testng diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/AbstractAggregatedMemoryContext.java 
b/lib/trino-memory-context/src/main/java/io/trino/memory/context/AbstractAggregatedMemoryContext.java index 08003679582c..8dbb66b16225 100644 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/AbstractAggregatedMemoryContext.java +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/AbstractAggregatedMemoryContext.java @@ -87,7 +87,7 @@ synchronized void addBytes(long bytes) usedBytes = addExact(usedBytes, bytes); } - abstract ListenableFuture updateBytes(String allocationTag, long bytes); + abstract ListenableFuture updateBytes(String allocationTag, long delta); abstract boolean tryUpdateBytes(String allocationTag, long delta); diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ChildAggregatedMemoryContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ChildAggregatedMemoryContext.java index 61bcb49d9bcb..36279e9fa56e 100644 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ChildAggregatedMemoryContext.java +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ChildAggregatedMemoryContext.java @@ -31,12 +31,12 @@ class ChildAggregatedMemoryContext } @Override - synchronized ListenableFuture updateBytes(String allocationTag, long bytes) + synchronized ListenableFuture updateBytes(String allocationTag, long delta) { checkState(!isClosed(), "ChildAggregatedMemoryContext is already closed"); // update the parent before updating usedBytes as it may throw a runtime exception (e.g., ExceededMemoryLimitException) - ListenableFuture future = parentMemoryContext.updateBytes(allocationTag, bytes); - addBytes(bytes); + ListenableFuture future = parentMemoryContext.updateBytes(allocationTag, delta); + addBytes(delta); return future; } diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java new file mode 100644 index 
000000000000..316a2a0b7b16 --- /dev/null +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.memory.context; + +public interface MemoryAllocationValidator +{ + MemoryAllocationValidator NO_MEMORY_VALIDATION = new MemoryAllocationValidator() { + @Override + public void reserveMemory(String allocationTag, long delta) {} + + @Override + public boolean tryReserveMemory(String allocationTag, long delta) + { + return true; + } + }; + + /** + * Check if memory can be reserved. Account for reserved memory if reservation is possible. Throw exception otherwise. Callers pass a negative {@code delta} to give back memory accounted earlier. + */ + void reserveMemory(String allocationTag, long delta); + + /** + * Check if memory can be reserved. Account for reserved memory if reservation is possible and return true. Return false otherwise. 
+ */ + boolean tryReserveMemory(String allocationTag, long delta); +} diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/RootAggregatedMemoryContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/RootAggregatedMemoryContext.java index 24acdb67e0f0..4efc11331821 100644 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/RootAggregatedMemoryContext.java +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/RootAggregatedMemoryContext.java @@ -31,11 +31,11 @@ class RootAggregatedMemoryContext } @Override - synchronized ListenableFuture updateBytes(String allocationTag, long bytes) + synchronized ListenableFuture updateBytes(String allocationTag, long delta) { checkState(!isClosed(), "RootAggregatedMemoryContext is already closed"); - ListenableFuture future = reservationHandler.reserveMemory(allocationTag, bytes); - addBytes(bytes); + ListenableFuture future = reservationHandler.reserveMemory(allocationTag, delta); + addBytes(delta); // make sure we never block queries below guaranteedMemory if (getBytes() < guaranteedMemory) { future = NOT_BLOCKED; diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/SimpleAggregatedMemoryContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/SimpleAggregatedMemoryContext.java index 7c804dc565ed..ab8ac0d108cd 100644 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/SimpleAggregatedMemoryContext.java +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/SimpleAggregatedMemoryContext.java @@ -24,10 +24,10 @@ class SimpleAggregatedMemoryContext extends AbstractAggregatedMemoryContext { @Override - synchronized ListenableFuture updateBytes(String allocationTag, long bytes) + synchronized ListenableFuture updateBytes(String allocationTag, long delta) { checkState(!isClosed(), "SimpleAggregatedMemoryContext is already closed"); - addBytes(bytes); + addBytes(delta); return 
NOT_BLOCKED; } diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java new file mode 100644 index 000000000000..e9d854528cf9 --- /dev/null +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.memory.context; + +import static java.util.Objects.requireNonNull; + +public class ValidatingAggregateContext + implements AggregatedMemoryContext +{ + private final AggregatedMemoryContext delegate; + private final MemoryAllocationValidator memoryValidator; + + public ValidatingAggregateContext(AggregatedMemoryContext delegate, MemoryAllocationValidator memoryValidator) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.memoryValidator = requireNonNull(memoryValidator, "memoryValidator is null"); + } + + @Override + public AggregatedMemoryContext newAggregatedMemoryContext() + { + return new ValidatingAggregateContext(delegate.newAggregatedMemoryContext(), memoryValidator); + } + + @Override + public LocalMemoryContext newLocalMemoryContext(String allocationTag) + { + return new ValidatingLocalMemoryContext(delegate.newLocalMemoryContext(allocationTag), allocationTag, memoryValidator); + } + + @Override + public long getBytes() + { + return delegate.getBytes(); + } 
+ + @Override + public void close() + { + delegate.close(); + } +} diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java new file mode 100644 index 000000000000..be2ee36f097f --- /dev/null +++ b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.memory.context; + +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; + +import static java.util.Objects.requireNonNull; + +public class ValidatingLocalMemoryContext + implements LocalMemoryContext +{ + private static final Logger log = Logger.get(ValidatingLocalMemoryContext.class); + + private final LocalMemoryContext delegate; + private final String allocationTag; + private final MemoryAllocationValidator memoryValidator; + + public ValidatingLocalMemoryContext(LocalMemoryContext delegate, String allocationTag, MemoryAllocationValidator memoryValidator) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.allocationTag = requireNonNull(allocationTag, "allocationTag is null"); + this.memoryValidator = requireNonNull(memoryValidator, "memoryValidator is null"); + } + + @Override + public long getBytes() + { + return delegate.getBytes(); + } + + @Override + public ListenableFuture setBytes(long bytes) + { + long delta = bytes - delegate.getBytes(); + + // first consult validator if allocation is possible + memoryValidator.reserveMemory(allocationTag, delta); + + try { + // do actual allocation; the delegate may throw a runtime exception (e.g., ExceededMemoryLimitException), in which case the validator reservation is reverted + return delegate.setBytes(bytes); + } + catch (Exception e) { + revertReservationInValidatorSuppressing(allocationTag, delta, e); + throw e; + } + } + + @Override + public boolean trySetBytes(long bytes) + { + long delta = bytes - delegate.getBytes(); + + if (!memoryValidator.tryReserveMemory(allocationTag, delta)) { + return false; + } + + try { + if (delegate.trySetBytes(bytes)) { + return true; + } + } + catch (Exception e) { + revertReservationInValidatorSuppressing(allocationTag, delta, e); + throw e; + } + + revertReservationInValidator(allocationTag, delta); + return false; + } + + @Override + public void close() + { + revertReservationInValidator(allocationTag, delegate.getBytes()); // closing the delegate frees its bytes, so release them from the validator as well + delegate.close(); + } + + private void
revertReservationInValidatorSuppressing(String allocationTag, long delta, Exception revertCause) + { + try { + revertReservationInValidator(allocationTag, delta); + } + catch (Exception suppressed) { + log.warn(suppressed, "Could not rollback memory reservation within allocation validator"); + if (suppressed != revertCause) { + revertCause.addSuppressed(suppressed); + } + } + } + + private void revertReservationInValidator(String allocationTag, long delta) + { + memoryValidator.reserveMemory(allocationTag, -delta); + } +} diff --git a/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java b/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java index 7aaaa9d3e534..4b758cc61b6a 100644 --- a/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java +++ b/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java @@ -22,6 +22,7 @@ import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; @@ -100,7 +101,7 @@ public void testTryReserve() } @Test - public void testHieararchicalMemoryContexts() + public void testHierarchicalMemoryContexts() { TestMemoryReservationHandler reservationHandler = new TestMemoryReservationHandler(1_000); AggregatedMemoryContext parentContext = newRootAggregatedMemoryContext(reservationHandler, GUARANTEED_MEMORY); @@ -164,16 +165,135 @@ public void testClosedAggregateMemoryContext() localContext.setBytes(100); } + @Test + public void testValidatingAggregateContext() + { + TestMemoryReservationHandler reservationHandler = new 
TestMemoryReservationHandler(1_000, true); + AggregatedMemoryContext rootContext = newRootAggregatedMemoryContext(reservationHandler, GUARANTEED_MEMORY); + + AggregatedMemoryContext childContext = new ValidatingAggregateContext(rootContext, new TestAllocationValidator(500)); + + LocalMemoryContext localContext = childContext.newLocalMemoryContext("test"); + + assertEquals(localContext.setBytes(500), NOT_BLOCKED); + assertEquals(localContext.getBytes(), 500); + assertEquals(rootContext.getBytes(), 500); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // reserve above validator limit + assertThatThrownBy(() -> localContext.setBytes(501)).hasMessage("limit exceeded"); + assertEquals(localContext.getBytes(), 500); + assertEquals(rootContext.getBytes(), 500); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // try reserve above validator limit + assertFalse(localContext.trySetBytes(501)); + assertEquals(localContext.getBytes(), 500); + assertEquals(rootContext.getBytes(), 500); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // unreserve a bit + assertEquals(localContext.setBytes(400), NOT_BLOCKED); + assertEquals(localContext.getBytes(), 400); + assertEquals(rootContext.getBytes(), 400); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // unreserve a bit using trySetBytes + assertTrue(localContext.trySetBytes(300)); + assertEquals(localContext.getBytes(), 300); + assertEquals(rootContext.getBytes(), 300); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // another context based directly on rootContext + LocalMemoryContext anotherLocalContext = rootContext.newLocalMemoryContext("another"); + + assertEquals(anotherLocalContext.setBytes(650), NOT_BLOCKED); + // total reservation is 950 at root level now + assertEquals(localContext.getBytes(), 300); + assertEquals(anotherLocalContext.getBytes(), 650); + 
assertEquals(rootContext.getBytes(), 950); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // exceed root context limit but be within validator boundaries + assertThatThrownBy(() -> localContext.setBytes(400)).hasMessage("out of memory"); + assertEquals(localContext.getBytes(), 300); + assertEquals(anotherLocalContext.getBytes(), 650); + assertEquals(rootContext.getBytes(), 950); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // exceed root context limit but be within validator boundaries using trySetBytes + assertFalse(localContext.trySetBytes(400)); + assertEquals(localContext.getBytes(), 300); + assertEquals(anotherLocalContext.getBytes(), 650); + assertEquals(rootContext.getBytes(), 950); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // if we free space in root context we can still allocate up to validator imposed limit + assertEquals(anotherLocalContext.setBytes(499), NOT_BLOCKED); + + // reserve using setBytes + assertEquals(localContext.setBytes(400), NOT_BLOCKED); + assertEquals(localContext.getBytes(), 400); + assertEquals(anotherLocalContext.getBytes(), 499); + assertEquals(rootContext.getBytes(), 899); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + + // reserve using trySetBytes + assertTrue(localContext.trySetBytes(500)); + assertEquals(localContext.getBytes(), 500); + assertEquals(anotherLocalContext.getBytes(), 499); + assertEquals(rootContext.getBytes(), 999); + assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); + } + + private static class TestAllocationValidator + implements MemoryAllocationValidator + { + private final long limit; + private long reserved; + + public TestAllocationValidator(long limit) + { + this.limit = limit; + } + + @Override + public void reserveMemory(String allocationTag, long delta) + { + if (reserved + delta > limit) { + throw new
IllegalArgumentException("limit exceeded"); + } + reserved = reserved + delta; + } + + @Override + public boolean tryReserveMemory(String allocationTag, long delta) + { + if (reserved + delta > limit) { + return false; + } + reserved = reserved + delta; + return true; + } + } + private static class TestMemoryReservationHandler implements MemoryReservationHandler { private long reservation; private final long maxMemory; + private final boolean throwWhenExceeded; private SettableFuture future; public TestMemoryReservationHandler(long maxMemory) + { + this(maxMemory, false); + } + + public TestMemoryReservationHandler(long maxMemory, boolean throwWhenExceeded) { this.maxMemory = maxMemory; + this.throwWhenExceeded = throwWhenExceeded; } public long getReservation() @@ -184,6 +304,9 @@ public long getReservation() @Override public ListenableFuture reserveMemory(String allocationTag, long delta) { + if (delta > 0 && reservation + delta > maxMemory && throwWhenExceeded) { + throw new IllegalStateException("out of memory"); + } reservation += delta; if (delta >= 0) { if (reservation >= maxMemory) { diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java index ea2dff1f9266..08235b6d17b9 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java @@ -302,6 +302,7 @@ protected Map runOnce() new QueryId("test"), DataSize.of(256, MEGABYTE), DataSize.of(512, MEGABYTE), + Optional.empty(), memoryPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), diff --git a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java index f7758a0e674e..6aeb566e47b1 100644 --- 
a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java +++ b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java @@ -74,6 +74,7 @@ public List execute(@Language("SQL") String query) new QueryId("test"), DataSize.of(1, GIGABYTE), DataSize.of(2, GIGABYTE), + Optional.empty(), memoryPool, new TestingGcMonitor(), localQueryRunner.getExecutor(),