From 9f8d86bae2cc869fc7c8d169720c8e0529f01a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Tue, 14 Dec 2021 16:05:32 +0100 Subject: [PATCH 01/11] Fix statement ordering --- .../src/main/java/io/trino/execution/NodeTaskMap.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java b/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java index edb1141e1ac1..5999d75a31a0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java +++ b/core/trino-main/src/main/java/io/trino/execution/NodeTaskMap.java @@ -97,16 +97,17 @@ private PartitionedSplitsInfo getPartitionedSplitsInfo() private void addTask(RemoteTask task) { if (remoteTasks.add(task)) { + // Check if task state is already done before adding the listener + if (task.getTaskStatus().getState().isDone()) { + remoteTasks.remove(task); + return; + } + task.addStateChangeListener(taskStatus -> { if (taskStatus.getState().isDone()) { remoteTasks.remove(task); } }); - - // Check if task state is already done before adding the listener - if (task.getTaskStatus().getState().isDone()) { - remoteTasks.remove(task); - } } } From 634b038bb3d17041c236e7fb67bde801aaf8a973 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Mon, 24 Jan 2022 13:08:55 +0100 Subject: [PATCH 02/11] Rename session property getter --- .../src/main/java/io/trino/SystemSessionProperties.java | 2 +- .../src/main/java/io/trino/execution/SqlTaskManager.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index c9ffc8507f77..998fbe04ce77 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -1178,7 +1178,7 @@ public static DataSize 
getQueryMaxMemoryPerNode(Session session) return session.getSystemProperty(QUERY_MAX_MEMORY_PER_NODE, DataSize.class); } - public static Optional getQueryMaxTotalMemoryPerTask(Session session) + public static Optional getQueryMaxMemoryPerTask(Session session) { return Optional.ofNullable(session.getSystemProperty(QUERY_MAX_MEMORY_PER_TASK, DataSize.class)); } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index 2c8ae73eff1a..0beb3ac9c66d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -71,7 +71,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.Threads.threadsNamed; import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode; -import static io.trino.SystemSessionProperties.getQueryMaxTotalMemoryPerTask; +import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerTask; import static io.trino.SystemSessionProperties.resourceOvercommit; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.execution.SqlTask.createSqlTask; @@ -401,7 +401,7 @@ private TaskInfo doUpdateTask( if (!queryContext.isMemoryLimitsInitialized()) { long sessionQueryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes(); - Optional effectiveQueryMaxMemoryPerTask = getQueryMaxTotalMemoryPerTask(session); + Optional effectiveQueryMaxMemoryPerTask = getQueryMaxMemoryPerTask(session); if (queryMaxMemoryPerTask.isPresent() && (effectiveQueryMaxMemoryPerTask.isEmpty() || effectiveQueryMaxMemoryPerTask.get().toBytes() > queryMaxMemoryPerTask.get().toBytes())) { effectiveQueryMaxMemoryPerTask = queryMaxMemoryPerTask; From af08f9a0ebbd11977a354a4d86e8028506017f27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Wed, 2 Mar 2022 16:50:56 -0800 
Subject: [PATCH 03/11] Remove redundant 'else' clauses --- .../trino/execution/scheduler/StageTaskSourceFactory.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index 84ae94b7826c..387f8dd35ef7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -130,14 +130,14 @@ public TaskSource create( if (partitioning.equals(SINGLE_DISTRIBUTION)) { return SingleDistributionTaskSource.create(fragment, exchangeSourceHandles); } - else if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { + if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { return ArbitraryDistributionTaskSource.create( fragment, sourceExchanges, exchangeSourceHandles, getFaultTolerantExecutionTargetTaskInputSize(session)); } - else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) { + if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnectorId().isPresent()) { return HashDistributionTaskSource.create( session, fragment, @@ -151,7 +151,7 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnect getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), getFaultTolerantExecutionTargetTaskInputSize(session)); } - else if (partitioning.equals(SOURCE_DISTRIBUTION)) { + if (partitioning.equals(SOURCE_DISTRIBUTION)) { return SourceDistributionTaskSource.create( session, fragment, From 32b9f2453a5bdc4230ab0ac0abfe0aaa8f821135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Fri, 10 Dec 2021 15:51:24 +0100 Subject: [PATCH
04/11] Introduce NodeAllocatorService --- .../io/trino/execution/SqlQueryExecution.java | 9 + .../scheduler/FixedCountNodeAllocator.java | 205 -------- .../FixedCountNodeAllocatorService.java | 287 +++++++++++ .../execution/scheduler/NodeAllocator.java | 2 - .../scheduler/NodeAllocatorService.java | 21 + .../scheduler/SqlQueryScheduler.java | 22 +- .../io/trino/server/CoordinatorModule.java | 5 + .../TestFaultTolerantStageScheduler.java | 469 +++++++++--------- .../TestFixedCountNodeAllocator.java | 60 ++- 9 files changed, 617 insertions(+), 463 deletions(-) delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 99aefd23ed06..4d16ef25e6fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -25,6 +25,7 @@ import io.trino.exchange.ExchangeManagerRegistry; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.SqlQueryScheduler; @@ -99,6 +100,7 @@ public class SqlQueryExecution private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final List planOptimizers; private final PlanFragmenter planFragmenter; 
private final RemoteTaskFactory remoteTaskFactory; @@ -132,6 +134,7 @@ private SqlQueryExecution( SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, List planOptimizers, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, @@ -159,6 +162,7 @@ private SqlQueryExecution( this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -497,6 +501,7 @@ private void planDistribution(PlanRoot plan) plan.getRoot(), nodePartitioningManager, nodeScheduler, + nodeAllocatorService, remoteTaskFactory, plan.isSummarizeTaskInfos(), scheduleSplitBatchSize, @@ -698,6 +703,7 @@ public static class SqlQueryExecutionFactory private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final List planOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; @@ -724,6 +730,7 @@ public static class SqlQueryExecutionFactory SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, PlanOptimizersFactory planOptimizersFactory, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, @@ -751,6 +758,7 @@ 
public static class SqlQueryExecutionFactory this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -790,6 +798,7 @@ public QueryExecution createQueryExecution( splitSourceFactory, nodePartitioningManager, nodeScheduler, + nodeAllocatorService, planOptimizers, planFragmenter, remoteTaskFactory, diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java deleted file mode 100644 index 5507285011a8..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocator.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.execution.scheduler; - -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.trino.Session; -import io.trino.connector.CatalogName; -import io.trino.metadata.InternalNode; -import io.trino.spi.TrinoException; - -import javax.annotation.concurrent.GuardedBy; - -import java.util.HashMap; -import java.util.IdentityHashMap; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.util.concurrent.Futures.immediateFailedFuture; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; -import static java.util.Comparator.comparing; -import static java.util.Objects.requireNonNull; - -public class FixedCountNodeAllocator - implements NodeAllocator -{ - private final NodeScheduler nodeScheduler; - - private final Session session; - private final int maximumAllocationsPerNode; - - @GuardedBy("this") - private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); - - @GuardedBy("this") - private final Map allocationCountMap = new HashMap<>(); - - @GuardedBy("this") - private final LinkedList pendingAcquires = new LinkedList<>(); - - public FixedCountNodeAllocator( - NodeScheduler nodeScheduler, - Session session, - int maximumAllocationsPerNode) - { - this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); - this.session = requireNonNull(session, "session is null"); - this.maximumAllocationsPerNode = maximumAllocationsPerNode; - } - - @Override - public synchronized ListenableFuture acquire(NodeRequirements requirements) - { - try { - Optional node = 
tryAcquireNode(requirements); - if (node.isPresent()) { - return immediateFuture(node.get()); - } - } - catch (RuntimeException e) { - return immediateFailedFuture(e); - } - - SettableFuture future = SettableFuture.create(); - PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); - pendingAcquires.add(pendingAcquire); - - return future; - } - - @Override - public void release(InternalNode node) - { - releaseNodeInternal(node); - processPendingAcquires(); - } - - @Override - public void updateNodes() - { - processPendingAcquires(); - } - - private synchronized Optional tryAcquireNode(NodeRequirements requirements) - { - NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); - - List nodes = nodeSelector.allNodes(); - if (nodes.isEmpty()) { - throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); - } - - List nodesMatchingRequirements = nodes.stream() - .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().contains(node.getHostAndPort())) - .collect(toImmutableList()); - - if (nodesMatchingRequirements.isEmpty()) { - throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); - } - - Optional selectedNode = nodesMatchingRequirements.stream() - .filter(node -> allocationCountMap.getOrDefault(node, 0) < maximumAllocationsPerNode) - .min(comparing(node -> allocationCountMap.getOrDefault(node, 0))); - - if (selectedNode.isEmpty()) { - return Optional.empty(); - } - - allocationCountMap.compute(selectedNode.get(), (key, value) -> value == null ? 1 : value + 1); - return selectedNode; - } - - private synchronized void releaseNodeInternal(InternalNode node) - { - int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 
0 : value - 1); - checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); - } - - private void processPendingAcquires() - { - verify(!Thread.holdsLock(this)); - - IdentityHashMap assignedNodes = new IdentityHashMap<>(); - IdentityHashMap failures = new IdentityHashMap<>(); - synchronized (this) { - Iterator iterator = pendingAcquires.iterator(); - while (iterator.hasNext()) { - PendingAcquire pendingAcquire = iterator.next(); - if (pendingAcquire.getFuture().isCancelled()) { - iterator.remove(); - continue; - } - try { - Optional node = tryAcquireNode(pendingAcquire.getNodeRequirements()); - if (node.isPresent()) { - iterator.remove(); - assignedNodes.put(pendingAcquire, node.get()); - } - } - catch (RuntimeException e) { - iterator.remove(); - failures.put(pendingAcquire, e); - } - } - } - - assignedNodes.forEach((pendingAcquire, node) -> { - SettableFuture future = pendingAcquire.getFuture(); - future.set(node); - if (future.isCancelled()) { - releaseNodeInternal(node); - } - }); - - failures.forEach((pendingAcquire, failure) -> { - SettableFuture future = pendingAcquire.getFuture(); - future.setException(failure); - }); - } - - @Override - public synchronized void close() - { - } - - private static class PendingAcquire - { - private final NodeRequirements nodeRequirements; - private final SettableFuture future; - - private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) - { - this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); - this.future = requireNonNull(future, "future is null"); - } - - public NodeRequirements getNodeRequirements() - { - return nodeRequirements; - } - - public SettableFuture getFuture() - { - return future; - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java new file mode 100644 index 000000000000..2faf06de4b40 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java @@ -0,0 +1,287 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import io.airlift.log.Logger; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.metadata.InternalNode; +import io.trino.spi.TrinoException; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Sets.newConcurrentHashSet; +import static com.google.common.util.concurrent.Futures.immediateFailedFuture; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; + +/** + * A simplistic node allocation service which only limits number of allocations per node within each + * {@link FixedCountNodeAllocator} instance. Each allocator will allow each node to be acquired up to {@link FixedCountNodeAllocatorService#MAXIMUM_ALLOCATIONS_PER_NODE} + * times at the same time. + */ +@ThreadSafe +public class FixedCountNodeAllocatorService + implements NodeAllocatorService +{ + private static final Logger log = Logger.get(FixedCountNodeAllocatorService.class); + + // Single FixedCountNodeAllocator will allow for at most MAXIMUM_ALLOCATIONS_PER_NODE. + // If we reach this state subsequent calls to acquire will return blocked lease. + private static final int MAXIMUM_ALLOCATIONS_PER_NODE = 1; // TODO make configurable? 
+ + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1, daemonThreadsNamed("fixed-count-node-allocator")); + private final NodeScheduler nodeScheduler; + + private final Set allocators = newConcurrentHashSet(); + private final AtomicBoolean started = new AtomicBoolean(); + + @Inject + public FixedCountNodeAllocatorService(NodeScheduler nodeScheduler) + { + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + } + + @PostConstruct + public void start() + { + if (!started.compareAndSet(false, true)) { + // already started + return; + } + executor.scheduleWithFixedDelay(() -> { + try { + updateNodes(); + } + catch (Throwable e) { + // ignore to avoid getting unscheduled + log.warn(e, "Error updating nodes"); + } + }, 5, 5, TimeUnit.SECONDS); + } + + @PreDestroy + public void stop() + { + executor.shutdownNow(); + } + + @VisibleForTesting + void updateNodes() + { + allocators.forEach(FixedCountNodeAllocator::updateNodes); + } + + @Override + public NodeAllocator getNodeAllocator(Session session) + { + requireNonNull(session, "session is null"); + return getNodeAllocator(session, MAXIMUM_ALLOCATIONS_PER_NODE); + } + + @VisibleForTesting + NodeAllocator getNodeAllocator(Session session, int maximumAllocationsPerNode) + { + FixedCountNodeAllocator allocator = new FixedCountNodeAllocator(session, maximumAllocationsPerNode); + allocators.add(allocator); + return allocator; + } + + private class FixedCountNodeAllocator + implements NodeAllocator + { + private final Session session; + private final int maximumAllocationsPerNode; + + @GuardedBy("this") + private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); + + @GuardedBy("this") + private final Map allocationCountMap = new HashMap<>(); + + @GuardedBy("this") + private final List pendingAcquires = new LinkedList<>(); + + public FixedCountNodeAllocator( + Session session, + int maximumAllocationsPerNode) + { + this.session = requireNonNull(session, 
"session is null"); + this.maximumAllocationsPerNode = maximumAllocationsPerNode; + } + + @Override + public synchronized ListenableFuture acquire(NodeRequirements requirements) + { + try { + Optional node = tryAcquireNode(requirements); + if (node.isPresent()) { + return immediateFuture(node.get()); + } + } + catch (RuntimeException e) { + return immediateFailedFuture(e); + } + + SettableFuture future = SettableFuture.create(); + PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); + pendingAcquires.add(pendingAcquire); + + return future; + } + + @Override + public void release(InternalNode node) + { + releaseNodeInternal(node); + processPendingAcquires(); + } + + public void updateNodes() + { + processPendingAcquires(); + } + + private synchronized Optional tryAcquireNode(NodeRequirements requirements) + { + NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); + + List nodes = nodeSelector.allNodes(); + if (nodes.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + List nodesMatchingRequirements = nodes.stream() + .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().contains(node.getHostAndPort())) + .collect(toImmutableList()); + + if (nodesMatchingRequirements.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + Optional selectedNode = nodesMatchingRequirements.stream() + .filter(node -> allocationCountMap.getOrDefault(node, 0) < maximumAllocationsPerNode) + .min(comparing(node -> allocationCountMap.getOrDefault(node, 0))); + + if (selectedNode.isEmpty()) { + return Optional.empty(); + } + + allocationCountMap.compute(selectedNode.get(), (key, value) -> value == null ? 
1 : value + 1); + return selectedNode; + } + + private synchronized void releaseNodeInternal(InternalNode node) + { + int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 0 : value - 1); + checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); + } + + private void processPendingAcquires() + { + verify(!Thread.holdsLock(this)); + + IdentityHashMap assignedNodes = new IdentityHashMap<>(); + IdentityHashMap failures = new IdentityHashMap<>(); + synchronized (this) { + Iterator iterator = pendingAcquires.iterator(); + while (iterator.hasNext()) { + PendingAcquire pendingAcquire = iterator.next(); + if (pendingAcquire.getFuture().isCancelled()) { + iterator.remove(); + continue; + } + try { + Optional node = tryAcquireNode(pendingAcquire.getNodeRequirements()); + if (node.isPresent()) { + iterator.remove(); + assignedNodes.put(pendingAcquire, node.get()); + } + } + catch (RuntimeException e) { + iterator.remove(); + failures.put(pendingAcquire, e); + } + } + } + + // set futures outside of critical section + assignedNodes.forEach((pendingAcquire, node) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.set(node); + if (future.isCancelled()) { + releaseNodeInternal(node); + } + }); + + failures.forEach((pendingAcquire, failure) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.setException(failure); + }); + } + + @Override + public synchronized void close() + { + allocators.remove(this); + } + } + + private static class PendingAcquire + { + private final NodeRequirements nodeRequirements; + private final SettableFuture future; + + private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) + { + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + this.future = requireNonNull(future, "future is null"); + } + + public NodeRequirements getNodeRequirements() + { + 
return nodeRequirements; + } + + public SettableFuture getFuture() + { + return future; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java index 778c059982e6..f7aaa038f934 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java @@ -25,8 +25,6 @@ public interface NodeAllocator void release(InternalNode node); - void updateNodes(); - @Override void close(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java new file mode 100644 index 000000000000..faea8c229f25 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocatorService.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import io.trino.Session; + +public interface NodeAllocatorService +{ + NodeAllocator getNodeAllocator(Session session); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index 85647ef02433..4baf816283fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -97,7 +97,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -175,6 +174,7 @@ public class SqlQueryScheduler private final QueryStateMachine queryStateMachine; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final NodeAllocatorService nodeAllocatorService; private final int splitBatchSize; private final ExecutorService executor; private final ScheduledExecutorService schedulerExecutor; @@ -210,6 +210,7 @@ public SqlQueryScheduler( SubPlan plan, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + NodeAllocatorService nodeAllocatorService, RemoteTaskFactory remoteTaskFactory, boolean summarizeTaskInfo, int splitBatchSize, @@ -231,6 +232,7 @@ public SqlQueryScheduler( this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is 
null"); this.splitBatchSize = splitBatchSize; this.executor = requireNonNull(queryExecutor, "queryExecutor is null"); this.schedulerExecutor = requireNonNull(schedulerExecutor, "schedulerExecutor is null"); @@ -342,7 +344,7 @@ private synchronized Optional createDistributedStage maxRetryAttempts, schedulerExecutor, schedulerStats, - nodeScheduler); + nodeAllocatorService); break; case QUERY: case NONE: @@ -1727,7 +1729,6 @@ private static class FaultTolerantDistributedStagesScheduler private final List schedulers; private final SplitSchedulerStats schedulerStats; private final NodeAllocator nodeAllocator; - private final ScheduledFuture nodeUpdateTask; private final AtomicBoolean started = new AtomicBoolean(); @@ -1743,7 +1744,7 @@ public static FaultTolerantDistributedStagesScheduler create( int retryAttempts, ScheduledExecutorService scheduledExecutorService, SplitSchedulerStats schedulerStats, - NodeScheduler nodeScheduler) + NodeAllocatorService nodeAllocatorService) { taskDescriptorStorage.initialize(queryStateMachine.getQueryId()); queryStateMachine.addStateChangeListener(state -> { @@ -1760,9 +1761,7 @@ public static FaultTolerantDistributedStagesScheduler create( ImmutableList.Builder schedulers = ImmutableList.builder(); Map exchanges = new HashMap<>(); - - FixedCountNodeAllocator nodeAllocator = new FixedCountNodeAllocator(nodeScheduler, session, 1); - ScheduledFuture nodeUpdateTask = scheduledExecutorService.scheduleAtFixedRate(nodeAllocator::updateNodes, 5, 5, SECONDS); + NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session); try { // root to children order @@ -1830,8 +1829,7 @@ public static FaultTolerantDistributedStagesScheduler create( queryStateMachine, schedulers.build(), schedulerStats, - nodeAllocator, - nodeUpdateTask); + nodeAllocator); } catch (Throwable t) { for (FaultTolerantStageScheduler scheduler : schedulers.build()) { @@ -1845,7 +1843,6 @@ public static FaultTolerantDistributedStagesScheduler create( } } - 
nodeUpdateTask.cancel(true); try { nodeAllocator.close(); } @@ -1933,15 +1930,13 @@ private FaultTolerantDistributedStagesScheduler( QueryStateMachine queryStateMachine, List schedulers, SplitSchedulerStats schedulerStats, - NodeAllocator nodeAllocator, - ScheduledFuture nodeUpdateTask) + NodeAllocator nodeAllocator) { this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); this.schedulers = requireNonNull(schedulers, "schedulers is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); - this.nodeUpdateTask = requireNonNull(nodeUpdateTask, "nodeUpdateTask is null"); } @Override @@ -2051,7 +2046,6 @@ public void abort() private void closeNodeAllocator() { - nodeUpdateTask.cancel(true); try { nodeAllocator.close(); } diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index ee5c0496f035..b002c719c3ab 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -60,6 +60,8 @@ import io.trino.execution.resourcegroups.InternalResourceGroupManager; import io.trino.execution.resourcegroups.LegacyResourceGroupConfigurationManager; import io.trino.execution.resourcegroups.ResourceGroupManager; +import io.trino.execution.scheduler.FixedCountNodeAllocatorService; +import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.StageTaskSourceFactory; import io.trino.execution.scheduler.TaskDescriptorStorage; @@ -209,6 +211,9 @@ protected void setup(Binder binder) bindLowMemoryKiller(LowMemoryKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES, 
TotalReservationOnBlockedNodesLowMemoryKiller.class); newExporter(binder).export(ClusterMemoryManager.class).withGeneratedName(); + // node allocator + binder.bind(NodeAllocatorService.class).to(FixedCountNodeAllocatorService.class).in(Scopes.SINGLETON); + // node monitor binder.bind(ClusterSizeMonitor.class).in(Scopes.SINGLETON); newExporter(binder).export(ClusterSizeMonitor.class).withGeneratedName(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index eb5bd3b9475d..775d49bb9e2a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -52,6 +52,7 @@ import io.trino.testing.TestingMetadata.TestingColumnHandle; import io.trino.util.FinalizerService; import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -103,6 +104,7 @@ public class TestFaultTolerantStageScheduler private FinalizerService finalizerService; private NodeTaskMap nodeTaskMap; + private FixedCountNodeAllocatorService nodeAllocatorService; @BeforeClass public void beforeClass() @@ -122,6 +124,21 @@ public void afterClass() } } + private void setupNodeAllocatorService(TestingNodeSupplier nodeSupplier) + { + shutdownNodeAllocatorService(); // just in case + nodeAllocatorService = new FixedCountNodeAllocatorService(new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier))); + } + + @AfterMethod(alwaysRun = true) + public void shutdownNodeAllocatorService() + { + if (nodeAllocatorService != null) { + nodeAllocatorService.stop(); + } + nodeAllocatorService = null; + } + @Test public void testHappyPath() throws Exception @@ -132,134 +149,137 @@ public void testHappyPath() 
NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG), NODE_3, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sinkExchange = new TestingExchange(false); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - createNodeAllocator(nodeSupplier), - TaskLifecycleListener.NO_OP, - Optional.of(sinkExchange), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2); - - ListenableFuture blocked = scheduler.isBlocked(); - assertUnblocked(blocked); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on first source exchange - assertBlocked(blocked); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - // still blocked on the second source exchange - assertBlocked(blocked); - assertBlocked(scheduler.isBlocked()); - - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - // now unblocked - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on node allocation - assertBlocked(blocked); - - // not all tasks have been enumerated yet - assertFalse(sinkExchange.isNoMoreSinks()); - - Map tasks = remoteTaskFactory.getTasks(); - // one task per node - assertThat(tasks).hasSize(3); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(1, 0)); - assertThat(tasks).containsKey(getTaskId(2, 0)); - - TestingRemoteTask task = tasks.get(getTaskId(0, 0)); - // fail task for partition 0 - task.fail(new RuntimeException("some failure")); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - // schedule more tasks - scheduler.schedule(); - - tasks = 
remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(4); - assertThat(tasks).containsKey(getTaskId(3, 0)); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - // finish some task - assertThat(tasks).containsKey(getTaskId(1, 0)); - tasks.get(getTaskId(1, 0)).finish(); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); - - // this will schedule failed task - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(5); - assertThat(tasks).containsKey(getTaskId(0, 1)); - - // finish some task - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(3, 0)); - tasks.get(getTaskId(3, 0)).finish(); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.of(sinkExchange), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); - assertUnblocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + assertUnblocked(blocked); - // schedule the last task - scheduler.schedule(); + scheduler.schedule(); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(6); - assertThat(tasks).containsKey(getTaskId(4, 0)); + blocked = scheduler.isBlocked(); + // blocked on first source exchange + assertBlocked(blocked); + + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // still blocked on the second source exchange + 
assertBlocked(blocked); + assertFalse(scheduler.isBlocked().isDone()); + + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + // now unblocked + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on node allocation + assertBlocked(blocked); + + // not all tasks have been enumerated yet + assertFalse(sinkExchange.isNoMoreSinks()); + + Map tasks = remoteTaskFactory.getTasks(); + // one task per node + assertThat(tasks).hasSize(3); + assertThat(tasks).containsKey(getTaskId(0, 0)); + assertThat(tasks).containsKey(getTaskId(1, 0)); + assertThat(tasks).containsKey(getTaskId(2, 0)); + + TestingRemoteTask task = tasks.get(getTaskId(0, 0)); + // fail task for partition 0 + task.fail(new RuntimeException("some failure")); + + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + + // schedule more tasks + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(4); + assertThat(tasks).containsKey(getTaskId(3, 0)); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + // finish some task + assertThat(tasks).containsKey(getTaskId(1, 0)); + tasks.get(getTaskId(1, 0)).finish(); + + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); + + // this will schedule failed task + scheduler.schedule(); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(5); + assertThat(tasks).containsKey(getTaskId(0, 1)); + + // finish some task + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(3, 0)); + tasks.get(getTaskId(3, 0)).finish(); + assertThat(sinkExchange.getFinishedSinkHandles()).contains(new 
TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); + + assertUnblocked(blocked); + + // schedule the last task + scheduler.schedule(); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).hasSize(6); + assertThat(tasks).containsKey(getTaskId(4, 0)); + + // not finished yet, will be finished when all tasks succeed + assertFalse(scheduler.isFinished()); + + blocked = scheduler.isBlocked(); + // blocked on task scheduling + assertBlocked(blocked); + + tasks = remoteTaskFactory.getTasks(); + assertThat(tasks).containsKey(getTaskId(4, 0)); + // finish remaining tasks + tasks.get(getTaskId(0, 1)).finish(); + tasks.get(getTaskId(2, 0)).finish(); + tasks.get(getTaskId(4, 0)).finish(); - // not finished yet, will be finished when all tasks succeed - assertFalse(scheduler.isFinished()); + // now it's not blocked and finished + assertUnblocked(blocked); + assertUnblocked(scheduler.isBlocked()); - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); + assertThat(sinkExchange.getFinishedSinkHandles()).contains( + new TestingExchangeSinkHandle(0), + new TestingExchangeSinkHandle(1), + new TestingExchangeSinkHandle(2), + new TestingExchangeSinkHandle(3), + new TestingExchangeSinkHandle(4)); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(4, 0)); - // finish remaining tasks - tasks.get(getTaskId(0, 1)).finish(); - tasks.get(getTaskId(2, 0)).finish(); - tasks.get(getTaskId(4, 0)).finish(); - - // now it's not blocked and finished - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - assertThat(sinkExchange.getFinishedSinkHandles()).contains( - new TestingExchangeSinkHandle(0), - new TestingExchangeSinkHandle(1), - new TestingExchangeSinkHandle(2), - new TestingExchangeSinkHandle(3), - new TestingExchangeSinkHandle(4)); - - assertTrue(scheduler.isFinished()); + assertTrue(scheduler.isFinished()); + } } @Test @@ -271,37 +291,40 @@ public void testTaskLifecycleListener() 
TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingTaskLifecycleListener taskLifecycleListener = new TestingTaskLifecycleListener(); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - createNodeAllocator(nodeSupplier), - taskLifecycleListener, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + taskLifecycleListener, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 2); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); + scheduler.schedule(); + assertBlocked(scheduler.isBlocked()); - assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0)); + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0)); - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some exception")); + 
remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some exception")); - assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); + assertUnblocked(scheduler.isBlocked()); + scheduler.schedule(); + assertBlocked(scheduler.isBlocked()); - assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0), getTaskId(0, 1)); + assertThat(taskLifecycleListener.getTasks().get(FRAGMENT_ID)).contains(getTaskId(0, 0), getTaskId(1, 0), getTaskId(0, 1)); + } } @Test @@ -313,46 +336,46 @@ public void testTaskFailure() TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + 
sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); + remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); - assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); - assertTrue(acquireNode1.isDone()); - assertTrue(acquireNode2.isDone()); + assertUnblocked(blocked); + assertUnblocked(acquireNode1); + assertUnblocked(acquireNode2); - assertThatThrownBy(scheduler::schedule) - .hasMessageContaining("some failure"); + assertThatThrownBy(scheduler::schedule) + .hasMessageContaining("some failure"); - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertFalse(scheduler.isFinished()); + } } @Test @@ -364,43 +387,45 @@ public void testReportTaskFailure() TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); 
TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 1); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 1); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting for tasks to finish - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting for tasks to finish + assertBlocked(blocked); - scheduler.reportTaskFailure(getTaskId(0, 0), new RuntimeException("some failure")); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); + scheduler.reportTaskFailure(getTaskId(0, 0), new RuntimeException("some failure")); + assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertUnblocked(blocked); - scheduler.schedule(); + assertUnblocked(blocked); 
+ scheduler.schedule(); - assertThat(remoteTaskFactory.getTasks()).containsKey(getTaskId(0, 1)); + assertThat(remoteTaskFactory.getTasks()).containsKey(getTaskId(0, 1)); - remoteTaskFactory.getTasks().get(getTaskId(0, 1)).finish(); - remoteTaskFactory.getTasks().get(getTaskId(1, 0)).finish(); + remoteTaskFactory.getTasks().get(getTaskId(0, 1)).finish(); + remoteTaskFactory.getTasks().get(getTaskId(1, 0)).finish(); - assertUnblocked(scheduler.isBlocked()); - assertTrue(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertTrue(scheduler.isFinished()); + } } @Test @@ -419,48 +444,50 @@ private void testCancellation(boolean abort) TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( NODE_1, ImmutableList.of(CATALOG), NODE_2, ImmutableList.of(CATALOG))); + setupNodeAllocatorService(nodeSupplier); TestingExchange sourceExchange1 = new TestingExchange(false); TestingExchange sourceExchange2 = new TestingExchange(false); - NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - TaskLifecycleListener.NO_OP, - Optional.empty(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0); + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { + FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( + remoteTaskFactory, + taskSourceFactory, + nodeAllocator, + TaskLifecycleListener.NO_OP, + Optional.empty(), + ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), + 0); - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); - assertUnblocked(scheduler.isBlocked()); + sourceExchange1.setSourceHandles(ImmutableList.of(new 
TestingExchangeSourceHandle(0, 1))); + sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 1))); + assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); + scheduler.schedule(); - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); + ListenableFuture blocked = scheduler.isBlocked(); + // waiting on node acquisition + assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - if (abort) { - scheduler.abort(); - } - else { - scheduler.cancel(); - } + if (abort) { + scheduler.abort(); + } + else { + scheduler.cancel(); + } - assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); + assertUnblocked(blocked); + assertUnblocked(acquireNode1); + assertUnblocked(acquireNode2); - scheduler.schedule(); + scheduler.schedule(); - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); + assertUnblocked(scheduler.isBlocked()); + assertFalse(scheduler.isFinished()); + } } private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( @@ -562,12 +589,6 @@ private static List createSplits(int count) return ImmutableList.copyOf(limit(cycle(new Split(CATALOG, createRemoteSplit(), Lifespan.taskWide())), count)); } - private NodeAllocator createNodeAllocator(TestingNodeSupplier nodeSupplier) - { - NodeScheduler nodeScheduler = new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier)); - return new FixedCountNodeAllocator(nodeScheduler, SESSION, 1); - } - private static 
TaskId getTaskId(int partitionId, int attemptId) { return new TaskId(STAGE_ID, partitionId, attemptId); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java index 569d1ad420ce..fb1067b8de41 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -23,6 +23,7 @@ import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; import io.trino.metadata.InternalNode; import io.trino.spi.HostAddress; +import org.testng.annotations.AfterMethod; import org.testng.annotations.Test; import java.net.URI; @@ -35,6 +36,8 @@ import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; +// uses mutable state +@Test(singleThreaded = true) public class TestFixedCountNodeAllocator { private static final Session SESSION = testSessionBuilder().build(); @@ -50,13 +53,31 @@ public class TestFixedCountNodeAllocator private static final CatalogName CATALOG_1 = new CatalogName("catalog1"); private static final CatalogName CATALOG_2 = new CatalogName("catalog2"); + private FixedCountNodeAllocatorService nodeAllocatorService; + + private void setupNodeAllocatorService(TestingNodeSupplier testingNodeSupplier) + { + shutdownNodeAllocatorService(); // just in case + nodeAllocatorService = new FixedCountNodeAllocatorService(new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier))); + } + + @AfterMethod(alwaysRun = true) + public void shutdownNodeAllocatorService() + { + if (nodeAllocatorService != null) { + nodeAllocatorService.stop(); + } + nodeAllocatorService = null; + } + @Test public void testSingleNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + 
setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -70,7 +91,7 @@ public void testSingleNode() assertEquals(acquire2.get(), NODE_1); } - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -100,8 +121,9 @@ public void testMultipleNodes() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -132,7 +154,7 @@ public void testMultipleNodes() assertEquals(acquire5.get(), NODE_1); } - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 2)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -189,7 +211,9 @@ public void testCatalogRequirement() NODE_2, ImmutableList.of(CATALOG_2), NODE_3, ImmutableList.of(CATALOG_1, CATALOG_2))); - 
try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + setupNodeAllocatorService(nodeSupplier); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertTrue(catalog1acquire1.isDone()); assertEquals(catalog1acquire1.get(), NODE_1); @@ -239,8 +263,9 @@ public void testCancellation() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -264,8 +289,9 @@ public void testAddNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -274,7 +300,7 @@ public void testAddNode() assertFalse(acquire2.isDone()); nodeSupplier.addNode(NODE_2, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertEquals(acquire2.get(10, SECONDS), NODE_2); } @@ -285,8 +311,9 @@ public void testRemoveNode() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); + 
setupNodeAllocatorService(nodeSupplier); - try (NodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_1); @@ -296,7 +323,7 @@ public void testRemoveNode() nodeSupplier.removeNode(NODE_1); nodeSupplier.addNode(NODE_2, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertEquals(acquire2.get(10, SECONDS), NODE_2); @@ -313,7 +340,9 @@ public void testAddressRequirement() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(), NODE_2, ImmutableList.of())); - try (FixedCountNodeAllocator nodeAllocator = createNodeAllocator(nodeSupplier, 1)) { + setupNodeAllocatorService(nodeSupplier); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); assertTrue(acquire1.isDone()); assertEquals(acquire1.get(), NODE_2); @@ -332,7 +361,7 @@ public void testAddressRequirement() .hasMessageContaining("No nodes available to run query"); nodeSupplier.addNode(NODE_3, ImmutableList.of()); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); assertTrue(acquire4.isDone()); @@ -342,7 +371,7 @@ public void testAddressRequirement() assertFalse(acquire5.isDone()); nodeSupplier.removeNode(NODE_3); - nodeAllocator.updateNodes(); + nodeAllocatorService.updateNodes(); assertTrue(acquire5.isDone()); assertThatThrownBy(acquire5::get) @@ -350,11 +379,6 @@ public void testAddressRequirement() } } - private 
FixedCountNodeAllocator createNodeAllocator(TestingNodeSupplier testingNodeSupplier, int maximumAllocationsPerNode) - { - return new FixedCountNodeAllocator(createNodeScheduler(testingNodeSupplier), SESSION, maximumAllocationsPerNode); - } - private NodeScheduler createNodeScheduler(TestingNodeSupplier testingNodeSupplier) { return new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier)); From 45f046c1f43c39ee8b5623ed349170e9c41a5601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Tue, 15 Feb 2022 14:18:18 +0100 Subject: [PATCH 05/11] Add NodeInfo wrapper for nodes allocated via NodeAllocator --- .../FaultTolerantStageScheduler.java | 13 +- .../FixedCountNodeAllocatorService.java | 23 +-- .../execution/scheduler/NodeAllocator.java | 5 +- .../trino/execution/scheduler/NodeInfo.java | 46 +++++ .../TestFaultTolerantStageScheduler.java | 8 +- .../TestFixedCountNodeAllocator.java | 186 +++++++++--------- 6 files changed, 163 insertions(+), 118 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/NodeInfo.java diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index eda402321959..9a05f50c49fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -34,7 +34,6 @@ import io.trino.execution.TaskStatus; import io.trino.execution.buffer.OutputBuffers; import io.trino.failuredetector.FailureDetector; -import io.trino.metadata.InternalNode; import io.trino.metadata.Split; import io.trino.spi.ErrorCode; import io.trino.spi.TrinoException; @@ -108,7 +107,7 @@ public class FaultTolerantStageScheduler private ListenableFuture blocked = immediateVoidFuture(); @GuardedBy("this") - private ListenableFuture 
acquireNodeFuture; + private ListenableFuture acquireNodeFuture; @GuardedBy("this") private SettableFuture taskFinishedFuture; @@ -121,7 +120,7 @@ public class FaultTolerantStageScheduler @GuardedBy("this") private final Map runningTasks = new HashMap<>(); @GuardedBy("this") - private final Map runningNodes = new HashMap<>(); + private final Map runningNodes = new HashMap<>(); @GuardedBy("this") private final Set allPartitions = new HashSet<>(); @GuardedBy("this") @@ -261,7 +260,7 @@ public synchronized void schedule() blocked = asVoid(acquireNodeFuture); return; } - InternalNode node = getFutureValue(acquireNodeFuture); + NodeInfo node = getFutureValue(acquireNodeFuture); acquireNodeFuture = null; queuedPartitions.poll(); @@ -300,7 +299,7 @@ public synchronized void schedule() .build(); RemoteTask task = stage.createTask( - node, + node.getNode(), partition, attemptId, sinkBucketToPartitionMap, @@ -403,7 +402,7 @@ private void cancelBlockedFuture() private void releaseAcquiredNode() { verify(!Thread.holdsLock(this)); - ListenableFuture future; + ListenableFuture future; synchronized (this) { future = acquireNodeFuture; acquireNodeFuture = null; @@ -498,7 +497,7 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional "node not found for task id: " + taskId); + NodeInfo node = requireNonNull(runningNodes.remove(taskId), () -> "node not found for task id: " + taskId); nodeAllocator.release(node); int partitionId = taskId.getPartitionId(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java index 2faf06de4b40..65fef12915eb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java @@ -47,6 +47,7 @@ import static 
com.google.common.util.concurrent.Futures.immediateFailedFuture; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.execution.scheduler.NodeInfo.unlimitedMemoryNode; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -147,19 +148,19 @@ public FixedCountNodeAllocator( } @Override - public synchronized ListenableFuture acquire(NodeRequirements requirements) + public synchronized ListenableFuture acquire(NodeRequirements requirements) { try { Optional node = tryAcquireNode(requirements); if (node.isPresent()) { - return immediateFuture(node.get()); + return immediateFuture(unlimitedMemoryNode(node.get())); } } catch (RuntimeException e) { return immediateFailedFuture(e); } - SettableFuture future = SettableFuture.create(); + SettableFuture future = SettableFuture.create(); PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); pendingAcquires.add(pendingAcquire); @@ -167,9 +168,9 @@ public synchronized ListenableFuture acquire(NodeRequirements requ } @Override - public void release(InternalNode node) + public void release(NodeInfo node) { - releaseNodeInternal(node); + releaseNodeInternal(node.getNode()); processPendingAcquires(); } @@ -243,15 +244,15 @@ private void processPendingAcquires() // set futures outside of critical section assignedNodes.forEach((pendingAcquire, node) -> { - SettableFuture future = pendingAcquire.getFuture(); - future.set(node); + SettableFuture future = pendingAcquire.getFuture(); + future.set(unlimitedMemoryNode(node)); if (future.isCancelled()) { releaseNodeInternal(node); } }); failures.forEach((pendingAcquire, failure) -> { - SettableFuture future = pendingAcquire.getFuture(); + SettableFuture future = pendingAcquire.getFuture(); future.setException(failure); }); } @@ -266,9 +267,9 @@ public 
synchronized void close() private static class PendingAcquire { private final NodeRequirements nodeRequirements; - private final SettableFuture future; + private final SettableFuture future; - private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) + private PendingAcquire(NodeRequirements nodeRequirements, SettableFuture future) { this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); this.future = requireNonNull(future, "future is null"); @@ -279,7 +280,7 @@ public NodeRequirements getNodeRequirements() return nodeRequirements; } - public SettableFuture getFuture() + public SettableFuture getFuture() { return future; } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java index f7aaa038f934..d4fadf168507 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java @@ -14,16 +14,15 @@ package io.trino.execution.scheduler; import com.google.common.util.concurrent.ListenableFuture; -import io.trino.metadata.InternalNode; import java.io.Closeable; public interface NodeAllocator extends Closeable { - ListenableFuture acquire(NodeRequirements requirements); + ListenableFuture acquire(NodeRequirements requirements); - void release(InternalNode node); + void release(NodeInfo node); @Override void close(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeInfo.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeInfo.java new file mode 100644 index 000000000000..200c039a23c3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeInfo.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import io.airlift.units.DataSize; +import io.trino.metadata.InternalNode; + +import static java.util.Objects.requireNonNull; + +public class NodeInfo +{ + private final InternalNode node; + private final DataSize maxMemory; + + public static NodeInfo unlimitedMemoryNode(InternalNode node) + { + return new NodeInfo(node, DataSize.ofBytes(Long.MAX_VALUE)); + } + + public NodeInfo(InternalNode node, DataSize maxMemory) + { + this.node = requireNonNull(node, "node is null"); + this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); + } + + public InternalNode getNode() + { + return node; + } + + public DataSize getMaxMemory() + { + return maxMemory; + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index 775d49bb9e2a..fb59d5f75596 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -361,8 +361,8 @@ public void testTaskFailure() // waiting on node acquisition assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new 
NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); @@ -469,8 +469,8 @@ private void testCancellation(boolean abort) // waiting on node acquisition assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); if (abort) { scheduler.abort(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java index fb1067b8de41..c8ab0f3a3a6e 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -78,41 +78,41 @@ public void testSingleNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture 
acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire2.isDone()); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire1.get()); assertTrue(acquire2.isDone()); - assertEquals(acquire2.get(), NODE_1); + assertEquals(acquire2.get().getNode(), NODE_1); } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire2.isDone()); - assertEquals(acquire2.get(), NODE_1); + assertEquals(acquire2.get().getNode(), NODE_1); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire3.isDone()); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire4.isDone()); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire2.get()); // NODE_1 assertTrue(acquire3.isDone()); - assertEquals(acquire3.get(), NODE_1); + assertEquals(acquire3.get().getNode(), NODE_1); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire3.get()); // NODE_1 assertTrue(acquire4.isDone()); - assertEquals(acquire4.get(), NODE_1); + 
assertEquals(acquire4.get().getNode(), NODE_1); } } @@ -124,81 +124,81 @@ public void testMultipleNodes() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire2.isDone()); - assertEquals(acquire2.get(), NODE_2); + assertEquals(acquire2.get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire3.isDone()); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire4.isDone()); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire5.isDone()); - nodeAllocator.release(NODE_2); + nodeAllocator.release(acquire2.get()); // NODE_2 assertTrue(acquire3.isDone()); - assertEquals(acquire3.get(), NODE_2); + assertEquals(acquire3.get().getNode(), NODE_2); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire1.get()); // NODE_1 
assertTrue(acquire4.isDone()); - assertEquals(acquire4.get(), NODE_1); + assertEquals(acquire4.get().getNode(), NODE_1); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire4.get()); // NODE_1 assertTrue(acquire5.isDone()); - assertEquals(acquire5.get(), NODE_1); + assertEquals(acquire5.get().getNode(), NODE_1); } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire2.isDone()); - assertEquals(acquire2.get(), NODE_2); + assertEquals(acquire2.get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire3.isDone()); - assertEquals(acquire3.get(), NODE_1); + assertEquals(acquire3.get().getNode(), NODE_1); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire4.isDone()); - assertEquals(acquire4.get(), NODE_2); + assertEquals(acquire4.get().getNode(), NODE_2); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), 
ImmutableSet.of())); assertFalse(acquire5.isDone()); - ListenableFuture acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire6.isDone()); - nodeAllocator.release(NODE_2); + nodeAllocator.release(acquire2.get()); // NODE_2 assertTrue(acquire5.isDone()); - assertEquals(acquire5.get(), NODE_2); + assertEquals(acquire5.get().getNode(), NODE_2); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire1.get()); // NODE_1 assertTrue(acquire6.isDone()); - assertEquals(acquire6.get(), NODE_1); + assertEquals(acquire6.get().getNode(), NODE_1); - ListenableFuture acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire7.isDone()); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire3.get()); // NODE_1 assertTrue(acquire7.isDone()); - assertEquals(acquire7.get(), NODE_1); + assertEquals(acquire7.get().getNode(), NODE_1); - nodeAllocator.release(NODE_1); - nodeAllocator.release(NODE_2); - nodeAllocator.release(NODE_2); + nodeAllocator.release(acquire6.get()); // NODE_1 + nodeAllocator.release(acquire5.get()); // NODE_2 + nodeAllocator.release(acquire4.get()); // NODE_2 - ListenableFuture acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire8.isDone()); - assertEquals(acquire8.get(), NODE_2); + assertEquals(acquire8.get().getNode(), NODE_2); } } @@ -214,47 +214,47 @@ public void testCatalogRequirement() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture 
catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + ListenableFuture catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertTrue(catalog1acquire1.isDone()); - assertEquals(catalog1acquire1.get(), NODE_1); + assertEquals(catalog1acquire1.get().getNode(), NODE_1); - ListenableFuture catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + ListenableFuture catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertTrue(catalog1acquire2.isDone()); - assertEquals(catalog1acquire2.get(), NODE_3); + assertEquals(catalog1acquire2.get().getNode(), NODE_3); - ListenableFuture catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + ListenableFuture catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertFalse(catalog1acquire3.isDone()); - ListenableFuture catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + ListenableFuture catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); assertTrue(catalog2acquire1.isDone()); - assertEquals(catalog2acquire1.get(), NODE_2); + assertEquals(catalog2acquire1.get().getNode(), NODE_2); - ListenableFuture catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + ListenableFuture catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); assertFalse(catalog2acquire2.isDone()); - nodeAllocator.release(NODE_2); + nodeAllocator.release(catalog2acquire1.get()); // NODE_2 assertFalse(catalog1acquire3.isDone()); assertTrue(catalog2acquire2.isDone()); - assertEquals(catalog2acquire2.get(), NODE_2); + 
assertEquals(catalog2acquire2.get().getNode(), NODE_2); - nodeAllocator.release(NODE_1); + nodeAllocator.release(catalog1acquire1.get()); // NODE_1 assertTrue(catalog1acquire3.isDone()); - assertEquals(catalog1acquire3.get(), NODE_1); + assertEquals(catalog1acquire3.get().getNode(), NODE_1); - ListenableFuture catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + ListenableFuture catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); assertFalse(catalog1acquire4.isDone()); - ListenableFuture catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + ListenableFuture catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); assertFalse(catalog2acquire4.isDone()); - nodeAllocator.release(NODE_3); + nodeAllocator.release(catalog1acquire2.get()); // NODE_3 assertFalse(catalog2acquire4.isDone()); assertTrue(catalog1acquire4.isDone()); - assertEquals(catalog1acquire4.get(), NODE_3); + assertEquals(catalog1acquire4.get().getNode(), NODE_3); - nodeAllocator.release(NODE_3); + nodeAllocator.release(catalog1acquire4.get()); // NODE_3 assertTrue(catalog2acquire4.isDone()); - assertEquals(catalog2acquire4.get(), NODE_3); + assertEquals(catalog2acquire4.get().getNode(), NODE_3); } } @@ -266,21 +266,21 @@ public void testCancellation() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new 
NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire2.isDone()); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire3.isDone()); acquire2.cancel(true); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire1.get()); // NODE_1 assertTrue(acquire3.isDone()); - assertEquals(acquire3.get(), NODE_1); + assertEquals(acquire3.get().getNode(), NODE_1); } } @@ -292,17 +292,17 @@ public void testAddNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire2.isDone()); nodeSupplier.addNode(NODE_2, ImmutableList.of()); nodeAllocatorService.updateNodes(); - assertEquals(acquire2.get(10, SECONDS), NODE_2); + assertEquals(acquire2.get(10, SECONDS).getNode(), NODE_2); } } @@ -314,23 +314,23 @@ public void testRemoveNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire1 = 
nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_1); + assertEquals(acquire1.get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire2.isDone()); nodeSupplier.removeNode(NODE_1); nodeSupplier.addNode(NODE_2, ImmutableList.of()); nodeAllocatorService.updateNodes(); - assertEquals(acquire2.get(10, SECONDS), NODE_2); + assertEquals(acquire2.get(10, SECONDS).getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); assertFalse(acquire3.isDone()); - nodeAllocator.release(NODE_1); + nodeAllocator.release(acquire1.get()); // NODE_1 assertFalse(acquire3.isDone()); } } @@ -343,19 +343,19 @@ public void testAddressRequirement() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); assertTrue(acquire1.isDone()); - assertEquals(acquire1.get(), NODE_2); + assertEquals(acquire1.get().getNode(), NODE_2); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); assertFalse(acquire2.isDone()); - nodeAllocator.release(NODE_2); + nodeAllocator.release(acquire1.get()); // NODE_2 
assertTrue(acquire2.isDone()); - assertEquals(acquire2.get(), NODE_2); + assertEquals(acquire2.get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); assertTrue(acquire3.isDone()); assertThatThrownBy(acquire3::get) .hasMessageContaining("No nodes available to run query"); @@ -363,11 +363,11 @@ public void testAddressRequirement() nodeSupplier.addNode(NODE_3, ImmutableList.of()); nodeAllocatorService.updateNodes(); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); assertTrue(acquire4.isDone()); - assertEquals(acquire4.get(), NODE_3); + assertEquals(acquire4.get().getNode(), NODE_3); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); assertFalse(acquire5.isDone()); nodeSupplier.removeNode(NODE_3); From 83b7c62ecab8480bd548e7a49024e939d956dcc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Wed, 2 Mar 2022 14:25:14 -0800 Subject: [PATCH 06/11] Drop mechanism of per memory task limit It turned out we do not need that functionality after all for task-level retries. Removing as currently we do not see the benefit of the mechanism and it increases complexity. 
--- .../trino/ExceededMemoryLimitException.java | 6 - .../io/trino/SystemSessionProperties.java | 11 -- .../io/trino/execution/SqlTaskManager.java | 16 +-- .../io/trino/memory/NodeMemoryConfig.java | 22 +--- .../java/io/trino/memory/QueryContext.java | 14 +- .../operator/TaskAllocationValidator.java | 88 ------------- .../java/io/trino/operator/TaskContext.java | 24 +--- .../io/trino/testing/TestingTaskContext.java | 2 - .../execution/MockRemoteTaskFactory.java | 1 - .../TestMemoryRevokingScheduler.java | 2 - .../java/io/trino/execution/TestSqlTask.java | 1 - .../trino/execution/TestSqlTaskExecution.java | 1 - .../java/io/trino/memory/TestMemoryPools.java | 2 - .../io/trino/memory/TestMemoryTracking.java | 22 ---- .../io/trino/memory/TestNodeMemoryConfig.java | 4 - .../io/trino/memory/TestQueryContext.java | 2 - .../operator/GroupByHashYieldAssertion.java | 2 - .../java/io/trino/spi/StandardErrorCode.java | 3 +- .../admin/properties-resource-management.rst | 10 -- lib/trino-memory-context/pom.xml | 5 - .../context/MemoryAllocationValidator.java | 38 ------ .../context/ValidatingAggregateContext.java | 53 -------- .../context/ValidatingLocalMemoryContext.java | 108 --------------- .../memory/context/TestMemoryContexts.java | 123 ------------------ .../benchmark/AbstractOperatorBenchmark.java | 1 - .../benchmark/MemoryLocalQueryRunner.java | 1 - 26 files changed, 10 insertions(+), 552 deletions(-) delete mode 100644 core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java delete mode 100644 lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java delete mode 100644 lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java delete mode 100644 lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java diff --git a/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java 
b/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java index b0f57af45bcd..6b95473610b4 100644 --- a/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java +++ b/core/trino-main/src/main/java/io/trino/ExceededMemoryLimitException.java @@ -40,12 +40,6 @@ public static ExceededMemoryLimitException exceededLocalUserMemoryLimit(DataSize format("Query exceeded per-node memory limit of %s [%s]", maxMemory, additionalFailureInfo)); } - public static ExceededMemoryLimitException exceededTaskMemoryLimit(DataSize maxMemory, String additionalFailureInfo) - { - return new ExceededMemoryLimitException(EXCEEDED_LOCAL_MEMORY_LIMIT, - format("Query exceeded per-task memory limit of %s [%s]", maxMemory, additionalFailureInfo)); - } - private ExceededMemoryLimitException(StandardErrorCode errorCode, String message) { super(errorCode, message); diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 998fbe04ce77..8368cf1e2716 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -130,7 +130,6 @@ public final class SystemSessionProperties public static final String ENABLE_COORDINATOR_DYNAMIC_FILTERS_DISTRIBUTION = "enable_coordinator_dynamic_filters_distribution"; public static final String ENABLE_LARGE_DYNAMIC_FILTERS = "enable_large_dynamic_filters"; public static final String QUERY_MAX_MEMORY_PER_NODE = "query_max_memory_per_node"; - public static final String QUERY_MAX_MEMORY_PER_TASK = "query_max_memory_per_task"; public static final String IGNORE_DOWNSTREAM_PREFERENCES = "ignore_downstream_preferences"; public static final String FILTERING_SEMI_JOIN_TO_INNER = "rewrite_filtering_semi_join_to_inner_join"; public static final String OPTIMIZE_DUPLICATE_INSENSITIVE_JOINS = "optimize_duplicate_insensitive_joins"; @@ -597,11 +596,6 @@ public 
SystemSessionProperties( "Maximum amount of memory a query can use per node", nodeMemoryConfig.getMaxQueryMemoryPerNode(), true), - dataSizeProperty( - QUERY_MAX_MEMORY_PER_TASK, - "Maximum amount of memory a single task can use", - nodeMemoryConfig.getMaxQueryMemoryPerTask().orElse(null), - true), booleanProperty( IGNORE_DOWNSTREAM_PREFERENCES, "Ignore Parent's PreferredProperties in AddExchange optimizer", @@ -1178,11 +1172,6 @@ public static DataSize getQueryMaxMemoryPerNode(Session session) return session.getSystemProperty(QUERY_MAX_MEMORY_PER_NODE, DataSize.class); } - public static Optional getQueryMaxMemoryPerTask(Session session) - { - return Optional.ofNullable(session.getSystemProperty(QUERY_MAX_MEMORY_PER_TASK, DataSize.class)); - } - public static boolean ignoreDownStreamPreferences(Session session) { return session.getSystemProperty(IGNORE_DOWNSTREAM_PREFERENCES, Boolean.class); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index 0beb3ac9c66d..5fb155a58dc6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -71,7 +71,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.Threads.threadsNamed; import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode; -import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerTask; import static io.trino.SystemSessionProperties.resourceOvercommit; import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; import static io.trino.execution.SqlTask.createSqlTask; @@ -105,7 +104,6 @@ public class SqlTaskManager private final SqlTaskIoStats finishedTaskStats = new SqlTaskIoStats(); private final long queryMaxMemoryPerNode; - private final Optional queryMaxMemoryPerTask; private final CounterStat failedTasks = new 
CounterStat(); @@ -144,13 +142,12 @@ public SqlTaskManager( SqlTaskExecutionFactory sqlTaskExecutionFactory = new SqlTaskExecutionFactory(taskNotificationExecutor, taskExecutor, planner, splitMonitor, config); DataSize maxQueryMemoryPerNode = nodeMemoryConfig.getMaxQueryMemoryPerNode(); - queryMaxMemoryPerTask = nodeMemoryConfig.getMaxQueryMemoryPerTask(); DataSize maxQuerySpillPerNode = nodeSpillConfig.getQueryMaxSpillPerNode(); queryMaxMemoryPerNode = maxQueryMemoryPerNode.toBytes(); queryContexts = buildNonEvictableCache(CacheBuilder.newBuilder().weakValues(), CacheLoader.from( - queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, queryMaxMemoryPerTask, maxQuerySpillPerNode))); + queryId -> createQueryContext(queryId, localMemoryManager, localSpillManager, gcMonitor, maxQueryMemoryPerNode, maxQuerySpillPerNode))); tasks = buildNonEvictableCache(CacheBuilder.newBuilder(), CacheLoader.from( taskId -> createSqlTask( @@ -173,13 +170,11 @@ private QueryContext createQueryContext( LocalSpillManager localSpillManager, GcMonitor gcMonitor, DataSize maxQueryUserMemoryPerNode, - Optional maxQueryMemoryPerTask, DataSize maxQuerySpillPerNode) { return new QueryContext( queryId, maxQueryUserMemoryPerNode, - maxQueryMemoryPerTask, localMemoryManager.getMemoryPool(), gcMonitor, taskNotificationExecutor, @@ -401,17 +396,10 @@ private TaskInfo doUpdateTask( if (!queryContext.isMemoryLimitsInitialized()) { long sessionQueryMaxMemoryPerNode = getQueryMaxMemoryPerNode(session).toBytes(); - Optional effectiveQueryMaxMemoryPerTask = getQueryMaxMemoryPerTask(session); - if (queryMaxMemoryPerTask.isPresent() && - (effectiveQueryMaxMemoryPerTask.isEmpty() || effectiveQueryMaxMemoryPerTask.get().toBytes() > queryMaxMemoryPerTask.get().toBytes())) { - effectiveQueryMaxMemoryPerTask = queryMaxMemoryPerTask; - } - // Session properties are only allowed to decrease memory limits, not increase them 
queryContext.initializeMemoryLimits( resourceOvercommit(session), - min(sessionQueryMaxMemoryPerNode, queryMaxMemoryPerNode), - effectiveQueryMaxMemoryPerTask); + min(sessionQueryMaxMemoryPerNode, queryMaxMemoryPerNode)); } sqlTask.recordHeartbeat(); diff --git a/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java b/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java index 7e0ef24df113..a7d54dacf3a7 100644 --- a/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java +++ b/core/trino-main/src/main/java/io/trino/memory/NodeMemoryConfig.java @@ -16,30 +16,25 @@ import io.airlift.configuration.Config; import io.airlift.configuration.ConfigDescription; import io.airlift.configuration.DefunctConfig; -import io.airlift.configuration.LegacyConfig; import io.airlift.units.DataSize; import javax.validation.constraints.NotNull; -import java.util.Optional; - // This is separate from MemoryManagerConfig because it's difficult to test the default value of maxQueryMemoryPerNode @DefunctConfig({ "deprecated.legacy-system-pool-enabled", "experimental.reserved-pool-disabled", "experimental.reserved-pool-enabled", "query.max-total-memory-per-node", + "query.max-memory-per-task" }) public class NodeMemoryConfig { public static final long AVAILABLE_HEAP_MEMORY = Runtime.getRuntime().maxMemory(); public static final String QUERY_MAX_MEMORY_PER_NODE_CONFIG = "query.max-memory-per-node"; - public static final String QUERY_MAX_MEMORY_PER_TASK_CONFIG = "query.max-memory-per-task"; private DataSize maxQueryMemoryPerNode = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3)); - private Optional maxQueryMemoryPerTask = Optional.empty(); - private DataSize heapHeadroom = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3)); @NotNull @@ -55,21 +50,6 @@ public NodeMemoryConfig setMaxQueryMemoryPerNode(DataSize maxQueryMemoryPerNode) return this; } - @NotNull - public Optional getMaxQueryMemoryPerTask() - { - return maxQueryMemoryPerTask; - } 
- - @Config(QUERY_MAX_MEMORY_PER_TASK_CONFIG) - @LegacyConfig("query.max-total-memory-per-task") - @ConfigDescription("Sets memory limit enforced for a single task; there is no memory limit by default") - public NodeMemoryConfig setMaxQueryMemoryPerTask(DataSize maxQueryMemoryPerTask) - { - this.maxQueryMemoryPerTask = Optional.ofNullable(maxQueryMemoryPerTask); - return this; - } - @NotNull public DataSize getHeapHeadroom() { diff --git a/core/trino-main/src/main/java/io/trino/memory/QueryContext.java b/core/trino-main/src/main/java/io/trino/memory/QueryContext.java index 3f374de27b57..778b1fb2671e 100644 --- a/core/trino-main/src/main/java/io/trino/memory/QueryContext.java +++ b/core/trino-main/src/main/java/io/trino/memory/QueryContext.java @@ -33,7 +33,6 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -74,8 +73,6 @@ public class QueryContext // TODO: This field should be final. 
However, due to the way QueryContext is constructed the memory limit is not known in advance @GuardedBy("this") private long maxUserMemory; - @GuardedBy("this") - private Optional maxTaskMemory; private final MemoryTrackingContext queryMemoryContext; private final MemoryPool memoryPool; @@ -86,7 +83,6 @@ public class QueryContext public QueryContext( QueryId queryId, DataSize maxUserMemory, - Optional maxTaskMemory, MemoryPool memoryPool, GcMonitor gcMonitor, Executor notificationExecutor, @@ -97,7 +93,6 @@ public QueryContext( this( queryId, maxUserMemory, - maxTaskMemory, memoryPool, GUARANTEED_MEMORY, gcMonitor, @@ -110,7 +105,6 @@ public QueryContext( public QueryContext( QueryId queryId, DataSize maxUserMemory, - Optional maxTaskMemory, MemoryPool memoryPool, long guaranteedMemory, GcMonitor gcMonitor, @@ -121,7 +115,6 @@ public QueryContext( { this.queryId = requireNonNull(queryId, "queryId is null"); this.maxUserMemory = requireNonNull(maxUserMemory, "maxUserMemory is null").toBytes(); - this.maxTaskMemory = requireNonNull(maxTaskMemory, "maxTaskMemory is null"); this.memoryPool = requireNonNull(memoryPool, "memoryPool is null"); this.gcMonitor = requireNonNull(gcMonitor, "gcMonitor is null"); this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); @@ -139,18 +132,16 @@ public boolean isMemoryLimitsInitialized() } // TODO: This method should be removed, and the correct limit set in the constructor. However, due to the way QueryContext is constructed the memory limit is not known in advance - public synchronized void initializeMemoryLimits(boolean resourceOverCommit, long maxUserMemory, Optional maxTaskMemory) + public synchronized void initializeMemoryLimits(boolean resourceOverCommit, long maxUserMemory) { checkArgument(maxUserMemory >= 0, "maxUserMemory must be >= 0, found: %s", maxUserMemory); if (resourceOverCommit) { // Allow the query to use the entire pool. 
This way the worker will kill the query, if it uses the entire local memory pool. // The coordinator will kill the query if the cluster runs out of memory. this.maxUserMemory = memoryPool.getMaxBytes(); - this.maxTaskMemory = Optional.empty(); // disabled } else { this.maxUserMemory = maxUserMemory; - this.maxTaskMemory = maxTaskMemory; } memoryLimitsInitialized = true; } @@ -260,8 +251,7 @@ public TaskContext addTaskContext( queryMemoryContext.newMemoryTrackingContext(), notifyStatusChanged, perOperatorCpuTimerEnabled, - cpuTimerEnabled, - maxTaskMemory); + cpuTimerEnabled); taskContexts.put(taskStateMachine.getTaskId(), taskContext); return taskContext; } diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java b/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java deleted file mode 100644 index 41edc9c6aa55..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/TaskAllocationValidator.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.operator; - -import io.airlift.units.DataSize; -import io.trino.memory.context.MemoryAllocationValidator; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - -import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; - -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.airlift.units.DataSize.succinctBytes; -import static io.trino.ExceededMemoryLimitException.exceededTaskMemoryLimit; -import static java.lang.String.format; -import static java.util.Map.Entry.comparingByValue; -import static java.util.Objects.requireNonNull; - -@ThreadSafe -// Keeps track of per-node memory usage of given task. Single instance is shared by multiple ValidatingLocalMemoryContext instances -// originating from single ValidatingAggregateContext. -public class TaskAllocationValidator - implements MemoryAllocationValidator -{ - private final long limitBytes; - @GuardedBy("this") - private long usedBytes; - @GuardedBy("this") - private final Map taggedAllocations = new HashMap<>(); - - public TaskAllocationValidator(DataSize memoryLimit) - { - this.limitBytes = requireNonNull(memoryLimit, "memoryLimit is null").toBytes(); - } - - @Override - public synchronized void reserveMemory(String allocationTag, long delta) - { - if (usedBytes + delta > limitBytes) { - verify(delta > 0, "exceeded limit with negative delta (%s); usedBytes=%s, limitBytes=%s", delta, usedBytes, limitBytes); - raiseLimitExceededFailure(allocationTag, delta); - } - usedBytes += delta; - taggedAllocations.merge(allocationTag, delta, Long::sum); - } - - private synchronized void raiseLimitExceededFailure(String currentAllocationTag, long currentAllocationDelta) - { - Map tmpTaggedAllocations = new HashMap<>(taggedAllocations); - // include current allocation in the output of top-consumers - tmpTaggedAllocations.merge(currentAllocationTag, 
currentAllocationDelta, Long::sum); - String topConsumers = tmpTaggedAllocations.entrySet().stream() - .sorted(comparingByValue(Comparator.reverseOrder())) - .limit(3) - .filter(e -> e.getValue() >= 0) - .collect(toImmutableMap(Map.Entry::getKey, e -> succinctBytes(e.getValue()))) - .toString(); - - String additionalInfo = format("Allocated: %s, Delta: %s, Top Consumers: %s", succinctBytes(usedBytes), succinctBytes(currentAllocationDelta), topConsumers); - throw exceededTaskMemoryLimit(DataSize.succinctBytes(limitBytes), additionalInfo); - } - - @Override - public synchronized boolean tryReserveMemory(String allocationTag, long delta) - { - if (usedBytes + delta > limitBytes) { - verify(delta > 0, "exceeded limit with negative delta (%s); usedBytes=%s, limitBytes=%s", delta, usedBytes, limitBytes); - return false; - } - usedBytes += delta; - return true; - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java index d24964f1b305..ed5005aed39c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java @@ -33,9 +33,7 @@ import io.trino.memory.QueryContext; import io.trino.memory.QueryContextVisitor; import io.trino.memory.context.LocalMemoryContext; -import io.trino.memory.context.MemoryAllocationValidator; import io.trino.memory.context.MemoryTrackingContext; -import io.trino.memory.context.ValidatingAggregateContext; import io.trino.spi.predicate.Domain; import io.trino.sql.planner.LocalDynamicFiltersCollector; import io.trino.sql.planner.plan.DynamicFilterId; @@ -46,7 +44,6 @@ import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; @@ -122,8 +119,7 @@ public static TaskContext createTaskContext( MemoryTrackingContext taskMemoryContext, Runnable 
notifyStatusChanged, boolean perOperatorCpuTimerEnabled, - boolean cpuTimerEnabled, - Optional maxMemory) + boolean cpuTimerEnabled) { TaskContext taskContext = new TaskContext( queryContext, @@ -135,8 +131,7 @@ public static TaskContext createTaskContext( taskMemoryContext, notifyStatusChanged, perOperatorCpuTimerEnabled, - cpuTimerEnabled, - maxMemory); + cpuTimerEnabled); taskContext.initialize(); return taskContext; } @@ -151,8 +146,7 @@ private TaskContext( MemoryTrackingContext taskMemoryContext, Runnable notifyStatusChanged, boolean perOperatorCpuTimerEnabled, - boolean cpuTimerEnabled, - Optional maxMemory) + boolean cpuTimerEnabled) { this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null"); this.gcMonitor = requireNonNull(gcMonitor, "gcMonitor is null"); @@ -160,17 +154,7 @@ private TaskContext( this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); this.yieldExecutor = requireNonNull(yieldExecutor, "yieldExecutor is null"); this.session = session; - - requireNonNull(taskMemoryContext, "taskMemoryContext is null"); - if (maxMemory.isPresent()) { - MemoryAllocationValidator memoryValidator = new TaskAllocationValidator(maxMemory.get()); - this.taskMemoryContext = new MemoryTrackingContext( - new ValidatingAggregateContext(taskMemoryContext.aggregateUserMemoryContext(), memoryValidator), - taskMemoryContext.aggregateRevocableMemoryContext()); - } - else { - this.taskMemoryContext = taskMemoryContext; - } + this.taskMemoryContext = requireNonNull(taskMemoryContext, "taskMemoryContext is null"); // Initialize the local memory contexts with the LazyOutputBuffer tag as LazyOutputBuffer will do the local memory allocations this.taskMemoryContext.initializeLocalMemoryContexts(LazyOutputBuffer.class.getSimpleName()); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java index 
6d3b33162a9e..6ab3dfdea00b 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingTaskContext.java @@ -26,7 +26,6 @@ import io.trino.spi.QueryId; import io.trino.spiller.SpillSpaceTracker; -import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -142,7 +141,6 @@ public TaskContext build() QueryContext queryContext = new QueryContext( queryId, queryMaxMemory, - Optional.empty(), memoryPool, 0L, GC_MONITOR, diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 2108fe6e2f57..fb1e31d69e34 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -196,7 +196,6 @@ public MockRemoteTask( SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(1, GIGABYTE)); QueryContext queryContext = new QueryContext(taskId.getQueryId(), DataSize.of(1, MEGABYTE), - Optional.empty(), memoryPool, new TestingGcMonitor(), executor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java index c54a5614a920..31932e6c88e2 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestMemoryRevokingScheduler.java @@ -44,7 +44,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ScheduledExecutorService; @@ -278,7 +277,6 @@ private QueryContext getOrCreateQueryContext(QueryId queryId) { return queryContexts.computeIfAbsent(queryId, id -> new 
QueryContext(id, DataSize.of(1, MEGABYTE), - Optional.empty(), memoryPool, new TestingGcMonitor(), executor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index f6db181ccf07..d300aa7e516a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -346,7 +346,6 @@ private SqlTask createInitialTask() QueryContext queryContext = new QueryContext(new QueryId("query"), DataSize.of(1, MEGABYTE), - Optional.empty(), new MemoryPool(DataSize.of(1, GIGABYTE)), new TestingGcMonitor(), taskNotificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 2e6f91ec2728..d9402ccfa50c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -596,7 +596,6 @@ private TaskContext newTestingTaskContext(ScheduledExecutorService taskNotificat QueryContext queryContext = new QueryContext( new QueryId("queryid"), DataSize.of(1, MEGABYTE), - Optional.empty(), new MemoryPool(DataSize.of(1, GIGABYTE)), new TestingGcMonitor(), taskNotificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java b/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java index da190489cc97..cb6b1fd05761 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryPools.java @@ -40,7 +40,6 @@ import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -94,7 +93,6 @@ private void setUp(Supplier> 
driversSupplier) SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(1, GIGABYTE)); QueryContext queryContext = new QueryContext(new QueryId("query"), TEN_MEGABYTES, - Optional.empty(), userPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java index 331b5af2705b..8e2dc2744227 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryTracking.java @@ -38,7 +38,6 @@ import org.testng.annotations.Test; import java.util.List; -import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -93,17 +92,11 @@ public void tearDown() @BeforeMethod public void setUpTest() - { - setupTestWithLimits(queryMaxMemory, Optional.empty()); - } - - private void setupTestWithLimits(DataSize queryMaxMemory, Optional queryMaxTaskMemory) { memoryPool = new MemoryPool(memoryPoolSize); queryContext = new QueryContext( new QueryId("test_query"), queryMaxMemory, - queryMaxTaskMemory, memoryPool, new TestingGcMonitor(), notificationExecutor, @@ -156,21 +149,6 @@ public void testLocalTotalMemoryLimitExceeded() .hasMessage("Query exceeded per-node memory limit of %1$s [Allocated: %1$s, Delta: 1B, Top Consumers: {test=%1$s}]", queryMaxMemory); } - @Test - public void testTaskMemoryLimitExceeded() - { - DataSize taskMaxMemory = DataSize.of(1, GIGABYTE); - setupTestWithLimits(DataSize.of(2, GIGABYTE), Optional.of(taskMaxMemory)); - LocalMemoryContext memoryContext = operatorContext.newLocalUserMemoryContext("test"); - memoryContext.setBytes(100); - assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 100, 0); - memoryContext.setBytes(taskMaxMemory.toBytes()); - assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 
taskMaxMemory.toBytes(), 0); - assertThatThrownBy(() -> memoryContext.setBytes(taskMaxMemory.toBytes() + 1)) - .isInstanceOf(ExceededMemoryLimitException.class) - .hasMessage("Query exceeded per-task memory limit of %1$s [Allocated: %s, Delta: 1B, Top Consumers: {test=%s}]", taskMaxMemory, DataSize.succinctBytes(taskMaxMemory.toBytes() + 1)); - } - @Test public void testLocalAllocations() { diff --git a/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java b/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java index 87499c756ee6..9e1177efb0e8 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestNodeMemoryConfig.java @@ -23,7 +23,6 @@ import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.memory.NodeMemoryConfig.AVAILABLE_HEAP_MEMORY; public class TestNodeMemoryConfig @@ -33,7 +32,6 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(NodeMemoryConfig.class) .setMaxQueryMemoryPerNode(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3))) - .setMaxQueryMemoryPerTask(null) .setHeapHeadroom(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.3)))); } @@ -42,13 +40,11 @@ public void testExplicitPropertyMappings() { Map properties = ImmutableMap.builder() .put("query.max-memory-per-node", "1GB") - .put("query.max-memory-per-task", "200MB") .put("memory.heap-headroom-per-node", "1GB") .buildOrThrow(); NodeMemoryConfig expected = new NodeMemoryConfig() .setMaxQueryMemoryPerNode(DataSize.of(1, GIGABYTE)) - .setMaxQueryMemoryPerTask(DataSize.of(200, MEGABYTE)) .setHeapHeadroom(DataSize.of(1, GIGABYTE)); assertFullMapping(properties, expected); diff --git 
a/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java b/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java index 84e96f13ab13..122117ef056c 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestQueryContext.java @@ -22,7 +22,6 @@ import org.testng.annotations.AfterClass; import org.testng.annotations.Test; -import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; import static io.airlift.concurrent.Threads.threadsNamed; @@ -47,7 +46,6 @@ public void testSetMemoryPool() QueryContext queryContext = new QueryContext( new QueryId("query"), DataSize.ofBytes(10), - Optional.empty(), new MemoryPool(DataSize.ofBytes(10)), new TestingGcMonitor(), localQueryRunner.getExecutor(), diff --git a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java index 3165b6647886..681e9d44eb02 100644 --- a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java +++ b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java @@ -26,7 +26,6 @@ import java.util.LinkedList; import java.util.List; -import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Function; @@ -82,7 +81,6 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< QueryContext queryContext = new QueryContext( queryId, DataSize.of(512, MEGABYTE), - Optional.empty(), memoryPool, new TestingGcMonitor(), EXECUTOR, diff --git a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java index 20c45fa9c4cc..59ae471582b4 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +++ 
b/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java @@ -170,8 +170,7 @@ public enum StandardErrorCode EXCEEDED_LOCAL_MEMORY_LIMIT(131079, INSUFFICIENT_RESOURCES), ADMINISTRATIVELY_PREEMPTED(131080, INSUFFICIENT_RESOURCES), EXCEEDED_SCAN_LIMIT(131081, INSUFFICIENT_RESOURCES), - EXCEEDED_TASK_MEMORY_LIMIT(131082, INSUFFICIENT_RESOURCES), - EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY(131083, INSUFFICIENT_RESOURCES), + EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY(131082, INSUFFICIENT_RESOURCES), /**/; diff --git a/docs/src/main/sphinx/admin/properties-resource-management.rst b/docs/src/main/sphinx/admin/properties-resource-management.rst index f831ff093787..ccbb9be4bc59 100644 --- a/docs/src/main/sphinx/admin/properties-resource-management.rst +++ b/docs/src/main/sphinx/admin/properties-resource-management.rst @@ -48,16 +48,6 @@ including revocable memory. When the memory allocated by a query across all workers hits this limit it is killed. The value of ``query.max-total-memory`` must be greater than ``query.max-memory``. -``query.max-memory-per-task`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -* **Type:** :ref:`prop-type-data-size` -* **Default value:** none, and therefore unrestricted -* **Session property:** ``query_max_total_memory_per_task`` - -This is the max amount of the memory a task can use on a node in the -cluster. Support for using this property is experimental only. 
- ``memory.heap-headroom-per-node`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/lib/trino-memory-context/pom.xml b/lib/trino-memory-context/pom.xml index 5d4c8467bb04..c4d77f43a0ce 100644 --- a/lib/trino-memory-context/pom.xml +++ b/lib/trino-memory-context/pom.xml @@ -18,11 +18,6 @@ - - io.airlift - log - - io.airlift units diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java deleted file mode 100644 index 316a2a0b7b16..000000000000 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/MemoryAllocationValidator.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.memory.context; - -public interface MemoryAllocationValidator -{ - MemoryAllocationValidator NO_MEMORY_VALIDATION = new MemoryAllocationValidator() { - @Override - public void reserveMemory(String allocationTag, long delta) {} - - @Override - public boolean tryReserveMemory(String allocationTag, long delta) - { - return true; - } - }; - - /** - * Check if memory can be reserved. Account for reserved memory if reservation is possible. Throw exception otherwise. - */ - void reserveMemory(String allocationTag, long delta); - - /** - * Check if memory can be reserved. Account for reserved memory if reservation is possible and return true. Return false otherwise. 
- */ - boolean tryReserveMemory(String allocationTag, long delta); -} diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java deleted file mode 100644 index e9d854528cf9..000000000000 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingAggregateContext.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.memory.context; - -import static java.util.Objects.requireNonNull; - -public class ValidatingAggregateContext - implements AggregatedMemoryContext -{ - private final AggregatedMemoryContext delegate; - private final MemoryAllocationValidator memoryValidator; - - public ValidatingAggregateContext(AggregatedMemoryContext delegate, MemoryAllocationValidator memoryValidator) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.memoryValidator = requireNonNull(memoryValidator, "memoryValidator is null"); - } - - @Override - public AggregatedMemoryContext newAggregatedMemoryContext() - { - return new ValidatingAggregateContext(delegate.newAggregatedMemoryContext(), memoryValidator); - } - - @Override - public LocalMemoryContext newLocalMemoryContext(String allocationTag) - { - return new ValidatingLocalMemoryContext(delegate.newLocalMemoryContext(allocationTag), allocationTag, memoryValidator); - } - - @Override - public long getBytes() - { - return delegate.getBytes(); - } - - @Override - public void close() - { - delegate.close(); - } -} diff --git a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java b/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java deleted file mode 100644 index be2ee36f097f..000000000000 --- a/lib/trino-memory-context/src/main/java/io/trino/memory/context/ValidatingLocalMemoryContext.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.memory.context; - -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.log.Logger; - -import static java.util.Objects.requireNonNull; - -public class ValidatingLocalMemoryContext - implements LocalMemoryContext -{ - private static final Logger log = Logger.get(ValidatingLocalMemoryContext.class); - - private final LocalMemoryContext delegate; - private final String allocationTag; - private final MemoryAllocationValidator memoryValidator; - - public ValidatingLocalMemoryContext(LocalMemoryContext delegate, String allocationTag, MemoryAllocationValidator memoryValidator) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.allocationTag = requireNonNull(allocationTag, "allocationTag is null"); - this.memoryValidator = requireNonNull(memoryValidator, "memoryValidator is null"); - } - - @Override - public long getBytes() - { - return delegate.getBytes(); - } - - @Override - public ListenableFuture setBytes(long bytes) - { - long delta = bytes - delegate.getBytes(); - - // first consult validator if allocation is possible - memoryValidator.reserveMemory(allocationTag, delta); - - // update the parent before updating usedBytes as it may throw a runtime exception (e.g., ExceededMemoryLimitException) - try { - // do actual allocation - return delegate.setBytes(bytes); - } - catch (Exception e) { - revertReservationInValidatorSuppressing(allocationTag, delta, e); - throw e; - } - } - - @Override - public boolean trySetBytes(long bytes) - { - long delta = bytes - delegate.getBytes(); - - if (!memoryValidator.tryReserveMemory(allocationTag, delta)) { - return false; - } - - try { - if (delegate.trySetBytes(bytes)) { - return true; - } - } - catch (Exception e) { - revertReservationInValidatorSuppressing(allocationTag, delta, e); - throw e; - } - - revertReservationInValidator(allocationTag, delta); - 
return false; - } - - @Override - public void close() - { - delegate.close(); - } - - private void revertReservationInValidatorSuppressing(String allocationTag, long delta, Exception revertCause) - { - try { - revertReservationInValidator(allocationTag, delta); - } - catch (Exception suppressed) { - log.warn(suppressed, "Could not rollback memory reservation within allocation validator"); - if (suppressed != revertCause) { - revertCause.addSuppressed(suppressed); - } - } - } - - private void revertReservationInValidator(String allocationTag, long delta) - { - memoryValidator.reserveMemory(allocationTag, -delta); - } -} diff --git a/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java b/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java index 4b758cc61b6a..df54ce687e1e 100644 --- a/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java +++ b/lib/trino-memory-context/src/test/java/io/trino/memory/context/TestMemoryContexts.java @@ -22,7 +22,6 @@ import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; @@ -165,135 +164,16 @@ public void testClosedAggregateMemoryContext() localContext.setBytes(100); } - @Test - public void testValidatingAggregateContext() - { - TestMemoryReservationHandler reservationHandler = new TestMemoryReservationHandler(1_000, true); - AggregatedMemoryContext rootContext = newRootAggregatedMemoryContext(reservationHandler, GUARANTEED_MEMORY); - - AggregatedMemoryContext childContext = new ValidatingAggregateContext(rootContext, new 
TestAllocationValidator(500)); - - LocalMemoryContext localContext = childContext.newLocalMemoryContext("test"); - - assertEquals(localContext.setBytes(500), NOT_BLOCKED); - assertEquals(localContext.getBytes(), 500); - assertEquals(rootContext.getBytes(), 500); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // reserve above validator limit - assertThatThrownBy(() -> localContext.setBytes(501)).hasMessage("limit exceeded"); - assertEquals(localContext.getBytes(), 500); - assertEquals(rootContext.getBytes(), 500); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // try reserve above validator limit - assertFalse(localContext.trySetBytes(501)); - assertEquals(localContext.getBytes(), 500); - assertEquals(rootContext.getBytes(), 500); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // unreserve a bit - assertEquals(localContext.setBytes(400), NOT_BLOCKED); - assertEquals(localContext.getBytes(), 400); - assertEquals(rootContext.getBytes(), 400); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // unreserve a bit using trySetBytes - assertTrue(localContext.trySetBytes(300)); - assertEquals(localContext.getBytes(), 300); - assertEquals(rootContext.getBytes(), 300); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // another context based directly on rootContext - LocalMemoryContext anotherLocalContext = rootContext.newLocalMemoryContext("another"); - - assertEquals(anotherLocalContext.setBytes(650), NOT_BLOCKED); - // total reservation is 950 at root level now - assertEquals(localContext.getBytes(), 300); - assertEquals(anotherLocalContext.getBytes(), 650); - assertEquals(rootContext.getBytes(), 950); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // exceed root context limit but be within validator boundaries - assertThatThrownBy(() -> 
localContext.setBytes(400)).hasMessage("out of memory"); - assertEquals(localContext.getBytes(), 300); - assertEquals(anotherLocalContext.getBytes(), 650); - assertEquals(rootContext.getBytes(), 950); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // exceed root context limit but be within validator boundaries using trySetBytes - assertFalse(localContext.trySetBytes(400)); - assertEquals(localContext.getBytes(), 300); - assertEquals(anotherLocalContext.getBytes(), 650); - assertEquals(rootContext.getBytes(), 950); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // if we free space in root context we can still allocate up to validator imposed limit - assertEquals(anotherLocalContext.setBytes(499), NOT_BLOCKED); - - // reserve using setBytes - assertEquals(localContext.setBytes(400), NOT_BLOCKED); - assertEquals(localContext.getBytes(), 400); - assertEquals(anotherLocalContext.getBytes(), 499); - assertEquals(rootContext.getBytes(), 899); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - - // reserve using trySetBytes - assertEquals(localContext.setBytes(500), NOT_BLOCKED); - assertEquals(localContext.getBytes(), 500); - assertEquals(anotherLocalContext.getBytes(), 499); - assertEquals(rootContext.getBytes(), 999); - assertEquals(reservationHandler.getReservation(), rootContext.getBytes()); - } - - private static class TestAllocationValidator - implements MemoryAllocationValidator - { - private final long limit; - private long reserved; - - public TestAllocationValidator(long limit) - { - this.limit = limit; - } - - @Override - public void reserveMemory(String allocationTag, long delta) - { - if (reserved + delta > limit) { - throw new IllegalArgumentException("limit exceeded"); - } - reserved = reserved + delta; - } - - @Override - public boolean tryReserveMemory(String allocationTag, long delta) - { - if (reserved + delta > limit) { - return false; - } - reserved = 
reserved + delta; - return true; - } - } - private static class TestMemoryReservationHandler implements MemoryReservationHandler { private long reservation; private final long maxMemory; - private final boolean throwWhenExceeded; private SettableFuture future; public TestMemoryReservationHandler(long maxMemory) - { - this(maxMemory, false); - } - - public TestMemoryReservationHandler(long maxMemory, boolean throwWhenExceeded) { this.maxMemory = maxMemory; - this.throwWhenExceeded = throwWhenExceeded; } public long getReservation() @@ -304,9 +184,6 @@ public long getReservation() @Override public ListenableFuture reserveMemory(String allocationTag, long delta) { - if (delta > 0 && reservation + delta > maxMemory && throwWhenExceeded) { - throw new IllegalStateException("out of memory"); - } reservation += delta; if (delta >= 0) { if (reservation >= maxMemory) { diff --git a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java index 9b20bfa9ba5a..65a580694784 100644 --- a/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java +++ b/testing/trino-benchmark/src/main/java/io/trino/benchmark/AbstractOperatorBenchmark.java @@ -300,7 +300,6 @@ protected Map runOnce() TaskContext taskContext = new QueryContext( new QueryId("test"), DataSize.of(256, MEGABYTE), - Optional.empty(), memoryPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), diff --git a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java index fe8d5607e6f5..488e51a1c3fd 100644 --- a/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java +++ b/testing/trino-benchmark/src/test/java/io/trino/benchmark/MemoryLocalQueryRunner.java @@ -72,7 +72,6 @@ public List execute(@Language("SQL") String query) 
QueryContext queryContext = new QueryContext( new QueryId("test"), DataSize.of(1, GIGABYTE), - Optional.empty(), memoryPool, new TestingGcMonitor(), localQueryRunner.getExecutor(), From b9f82a3c26e02adb518c42d2fd17edbec3ca06e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Tue, 14 Dec 2021 23:25:54 +0100 Subject: [PATCH 07/11] Use explicit lease object in NodeAllocator --- .../FaultTolerantStageScheduler.java | 35 +- .../FixedCountNodeAllocatorService.java | 59 +++- .../execution/scheduler/NodeAllocator.java | 17 +- .../TestFaultTolerantStageScheduler.java | 16 +- .../TestFixedCountNodeAllocator.java | 316 +++++++++--------- 5 files changed, 237 insertions(+), 206 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index 9a05f50c49fe..e11c8a1567a1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -107,7 +107,7 @@ public class FaultTolerantStageScheduler private ListenableFuture blocked = immediateVoidFuture(); @GuardedBy("this") - private ListenableFuture acquireNodeFuture; + private NodeAllocator.NodeLease nodeLease; @GuardedBy("this") private SettableFuture taskFinishedFuture; @@ -120,7 +120,7 @@ public class FaultTolerantStageScheduler @GuardedBy("this") private final Map runningTasks = new HashMap<>(); @GuardedBy("this") - private final Map runningNodes = new HashMap<>(); + private final Map runningNodes = new HashMap<>(); @GuardedBy("this") private final Set allPartitions = new HashSet<>(); @GuardedBy("this") @@ -253,15 +253,14 @@ public synchronized void schedule() } TaskDescriptor taskDescriptor = taskDescriptorOptional.get(); - if (acquireNodeFuture == null) { - acquireNodeFuture = 
nodeAllocator.acquire(taskDescriptor.getNodeRequirements()); + if (nodeLease == null) { + nodeLease = nodeAllocator.acquire(taskDescriptor.getNodeRequirements()); } - if (!acquireNodeFuture.isDone()) { - blocked = asVoid(acquireNodeFuture); + if (!nodeLease.getNode().isDone()) { + blocked = asVoid(nodeLease.getNode()); return; } - NodeInfo node = getFutureValue(acquireNodeFuture); - acquireNodeFuture = null; + NodeInfo node = getFutureValue(nodeLease.getNode()); queuedPartitions.poll(); @@ -311,7 +310,8 @@ public synchronized void schedule() partitionToRemoteTaskMap.put(partition, task); runningTasks.put(task.getTaskId(), task); - runningNodes.put(task.getTaskId(), node); + runningNodes.put(task.getTaskId(), nodeLease); + nodeLease = null; if (taskFinishedFuture == null) { taskFinishedFuture = SettableFuture.create(); @@ -402,16 +402,13 @@ private void cancelBlockedFuture() private void releaseAcquiredNode() { verify(!Thread.holdsLock(this)); - ListenableFuture future; + NodeAllocator.NodeLease lease; synchronized (this) { - future = acquireNodeFuture; - acquireNodeFuture = null; + lease = nodeLease; + nodeLease = null; } - if (future != null) { - future.cancel(true); - if (future.isDone() && !future.isCancelled()) { - nodeAllocator.release(getFutureValue(future)); - } + if (lease != null) { + lease.release(); } } @@ -497,8 +494,8 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional "node not found for task id: " + taskId); - nodeAllocator.release(node); + NodeAllocator.NodeLease nodeLease = requireNonNull(runningNodes.remove(taskId), () -> "node not found for task id: " + taskId); + nodeLease.release(); int partitionId = taskId.getPartitionId(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java index 65fef12915eb..48ea2585d9aa 100644 --- 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountNodeAllocatorService.java @@ -46,6 +46,7 @@ import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.util.concurrent.Futures.immediateFailedFuture; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.execution.scheduler.NodeInfo.unlimitedMemoryNode; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; @@ -148,30 +149,23 @@ public FixedCountNodeAllocator( } @Override - public synchronized ListenableFuture acquire(NodeRequirements requirements) + public synchronized NodeLease acquire(NodeRequirements requirements) { try { Optional node = tryAcquireNode(requirements); if (node.isPresent()) { - return immediateFuture(unlimitedMemoryNode(node.get())); + return new FixedCountNodeLease(immediateFuture(unlimitedMemoryNode(node.get()))); } } catch (RuntimeException e) { - return immediateFailedFuture(e); + return new FixedCountNodeLease(immediateFailedFuture(e)); } SettableFuture future = SettableFuture.create(); PendingAcquire pendingAcquire = new PendingAcquire(requirements, future); pendingAcquires.add(pendingAcquire); - return future; - } - - @Override - public void release(NodeInfo node) - { - releaseNodeInternal(node.getNode()); - processPendingAcquires(); + return new FixedCountNodeLease(future); } public void updateNodes() @@ -208,10 +202,13 @@ private synchronized Optional tryAcquireNode(NodeRequirements requ return selectedNode; } - private synchronized void releaseNodeInternal(InternalNode node) + private void releaseNode(InternalNode node) { - int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 
0 : value - 1); - checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); + synchronized (this) { + int allocationCount = allocationCountMap.compute(node, (key, value) -> value == null ? 0 : value - 1); + checkState(allocationCount >= 0, "allocation count for node %s is expected to be greater than or equal to zero: %s", node, allocationCount); + } + processPendingAcquires(); } private void processPendingAcquires() @@ -247,7 +244,7 @@ private void processPendingAcquires() SettableFuture future = pendingAcquire.getFuture(); future.set(unlimitedMemoryNode(node)); if (future.isCancelled()) { - releaseNodeInternal(node); + releaseNode(node); } }); @@ -262,6 +259,38 @@ public synchronized void close() { allocators.remove(this); } + + private class FixedCountNodeLease + implements NodeAllocator.NodeLease + { + private final ListenableFuture node; + private final AtomicBoolean released = new AtomicBoolean(); + + private FixedCountNodeLease(ListenableFuture node) + { + this.node = requireNonNull(node, "node is null"); + } + + @Override + public ListenableFuture getNode() + { + return node; + } + + @Override + public void release() + { + if (released.compareAndSet(false, true)) { + node.cancel(true); + if (node.isDone() && !node.isCancelled()) { + releaseNode(getFutureValue(node).getNode()); + } + } + else { + throw new IllegalStateException("Node " + node + " already released"); + } + } + } } private static class PendingAcquire diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java index d4fadf168507..9cc2afac024e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeAllocator.java @@ -20,10 +20,21 @@ public interface NodeAllocator extends Closeable { - ListenableFuture 
acquire(NodeRequirements requirements); - - void release(NodeInfo node); + /** + * Requests acquisition of node. Obtained node can be obtained via {@link NodeLease#getNode()} method. + * The node may not be available immediately. Calling party needs to wait until future returned is done. + * + * It is obligatory for the calling party to release all the leases they obtained via {@link NodeLease#release()}. + */ + NodeLease acquire(NodeRequirements requirements); @Override void close(); + + interface NodeLease + { + ListenableFuture getNode(); + + void release(); + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index fb59d5f75596..6a6774a0030e 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -361,14 +361,14 @@ public void testTaskFailure() // waiting on node acquisition assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); + assertUnblocked(acquireNode1.getNode()); + assertUnblocked(acquireNode2.getNode()); assertThatThrownBy(scheduler::schedule) .hasMessageContaining("some failure"); @@ -469,8 +469,8 @@ 
private void testCancellation(boolean abort) // waiting on node acquisition assertBlocked(blocked); - ListenableFuture acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - ListenableFuture acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); if (abort) { scheduler.abort(); @@ -480,8 +480,8 @@ private void testCancellation(boolean abort) } assertUnblocked(blocked); - assertUnblocked(acquireNode1); - assertUnblocked(acquireNode2); + assertUnblocked(acquireNode1.getNode()); + assertUnblocked(acquireNode2.getNode()); scheduler.schedule(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java index c8ab0f3a3a6e..30de5556708a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.util.concurrent.ListenableFuture; import io.trino.Session; import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; @@ -78,41 +77,41 @@ public void testSingleNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - 
assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire2.isDone()); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire2.getNode().isDone()); - nodeAllocator.release(acquire1.get()); + acquire1.release(); - assertTrue(acquire2.isDone()); - assertEquals(acquire2.get().getNode(), NODE_1); + assertTrue(acquire2.getNode().isDone()); + assertEquals(acquire2.getNode().get().getNode(), NODE_1); } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire2.isDone()); - assertEquals(acquire2.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire2.getNode().isDone()); + assertEquals(acquire2.getNode().get().getNode(), NODE_1); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire3.isDone()); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), 
ImmutableSet.of())); + assertFalse(acquire3.getNode().isDone()); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire4.isDone()); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire4.getNode().isDone()); - nodeAllocator.release(acquire2.get()); // NODE_1 - assertTrue(acquire3.isDone()); - assertEquals(acquire3.get().getNode(), NODE_1); + acquire2.release(); // NODE_1 + assertTrue(acquire3.getNode().isDone()); + assertEquals(acquire3.getNode().get().getNode(), NODE_1); - nodeAllocator.release(acquire3.get()); // NODE_1 - assertTrue(acquire4.isDone()); - assertEquals(acquire4.get().getNode(), NODE_1); + acquire3.release(); // NODE_1 + assertTrue(acquire4.getNode().isDone()); + assertEquals(acquire4.getNode().get().getNode(), NODE_1); } } @@ -124,81 +123,81 @@ public void testMultipleNodes() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire2.isDone()); - assertEquals(acquire2.get().getNode(), NODE_2); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire2.getNode().isDone()); + assertEquals(acquire2.getNode().get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new 
NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire3.isDone()); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire3.getNode().isDone()); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire4.isDone()); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire4.getNode().isDone()); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire5.isDone()); + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire5.getNode().isDone()); - nodeAllocator.release(acquire2.get()); // NODE_2 - assertTrue(acquire3.isDone()); - assertEquals(acquire3.get().getNode(), NODE_2); + acquire2.release(); // NODE_2 + assertTrue(acquire3.getNode().isDone()); + assertEquals(acquire3.getNode().get().getNode(), NODE_2); - nodeAllocator.release(acquire1.get()); // NODE_1 - assertTrue(acquire4.isDone()); - assertEquals(acquire4.get().getNode(), NODE_1); + acquire1.release(); // NODE_1 + assertTrue(acquire4.getNode().isDone()); + assertEquals(acquire4.getNode().get().getNode(), NODE_1); - nodeAllocator.release(acquire4.get()); // NODE_1 - assertTrue(acquire5.isDone()); - assertEquals(acquire5.get().getNode(), NODE_1); + acquire4.release(); //NODE_1 + assertTrue(acquire5.getNode().isDone()); + assertEquals(acquire5.getNode().get().getNode(), NODE_1); } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = 
nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire2.isDone()); - assertEquals(acquire2.get().getNode(), NODE_2); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire2.getNode().isDone()); + assertEquals(acquire2.getNode().get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire3.isDone()); - assertEquals(acquire3.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire3.getNode().isDone()); + assertEquals(acquire3.getNode().get().getNode(), NODE_1); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire4.isDone()); - assertEquals(acquire4.get().getNode(), NODE_2); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire4.getNode().isDone()); + assertEquals(acquire4.getNode().get().getNode(), NODE_2); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire5.isDone()); + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire5.getNode().isDone()); - ListenableFuture acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire6.isDone()); + NodeAllocator.NodeLease acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), 
ImmutableSet.of())); + assertFalse(acquire6.getNode().isDone()); - nodeAllocator.release(acquire2.get()); // NODE_2 - assertTrue(acquire5.isDone()); - assertEquals(acquire5.get().getNode(), NODE_2); + acquire4.release(); // NODE_2 + assertTrue(acquire5.getNode().isDone()); + assertEquals(acquire5.getNode().get().getNode(), NODE_2); - nodeAllocator.release(acquire1.get()); // NODE_1 - assertTrue(acquire6.isDone()); - assertEquals(acquire6.get().getNode(), NODE_1); + acquire3.release(); // NODE_1 + assertTrue(acquire6.getNode().isDone()); + assertEquals(acquire6.getNode().get().getNode(), NODE_1); - ListenableFuture acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire7.isDone()); + NodeAllocator.NodeLease acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire7.getNode().isDone()); - nodeAllocator.release(acquire3.get()); // NODE_1 - assertTrue(acquire7.isDone()); - assertEquals(acquire7.get().getNode(), NODE_1); + acquire6.release(); // NODE_1 + assertTrue(acquire7.getNode().isDone()); + assertEquals(acquire7.getNode().get().getNode(), NODE_1); - nodeAllocator.release(acquire6.get()); // NODE_1 - nodeAllocator.release(acquire5.get()); // NODE_2 - nodeAllocator.release(acquire4.get()); // NODE_2 + acquire7.release(); // NODE_1 + acquire5.release(); // NODE_2 + acquire2.release(); // NODE_2 - ListenableFuture acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire8.isDone()); - assertEquals(acquire8.get().getNode(), NODE_2); + NodeAllocator.NodeLease acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire8.getNode().isDone()); + assertEquals(acquire8.getNode().get().getNode(), NODE_2); } } @@ -214,73 +213,73 @@ public void testCatalogRequirement() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = 
nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); - assertTrue(catalog1acquire1.isDone()); - assertEquals(catalog1acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + assertTrue(catalog1acquire1.getNode().isDone()); + assertEquals(catalog1acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); - assertTrue(catalog1acquire2.isDone()); - assertEquals(catalog1acquire2.get().getNode(), NODE_3); + NodeAllocator.NodeLease catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + assertTrue(catalog1acquire2.getNode().isDone()); + assertEquals(catalog1acquire2.getNode().get().getNode(), NODE_3); - ListenableFuture catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); - assertFalse(catalog1acquire3.isDone()); + NodeAllocator.NodeLease catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + assertFalse(catalog1acquire3.getNode().isDone()); - ListenableFuture catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); - assertTrue(catalog2acquire1.isDone()); - assertEquals(catalog2acquire1.get().getNode(), NODE_2); + NodeAllocator.NodeLease catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + assertTrue(catalog2acquire1.getNode().isDone()); + assertEquals(catalog2acquire1.getNode().get().getNode(), NODE_2); - ListenableFuture catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); - 
assertFalse(catalog2acquire2.isDone()); + NodeAllocator.NodeLease catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + assertFalse(catalog2acquire2.getNode().isDone()); - nodeAllocator.release(catalog2acquire1.get()); // NODE_2 - assertFalse(catalog1acquire3.isDone()); - assertTrue(catalog2acquire2.isDone()); - assertEquals(catalog2acquire2.get().getNode(), NODE_2); + catalog2acquire1.release(); // NODE_2 + assertFalse(catalog1acquire3.getNode().isDone()); + assertTrue(catalog2acquire2.getNode().isDone()); + assertEquals(catalog2acquire2.getNode().get().getNode(), NODE_2); - nodeAllocator.release(catalog1acquire1.get()); // NODE_1 - assertTrue(catalog1acquire3.isDone()); - assertEquals(catalog1acquire3.get().getNode(), NODE_1); + catalog1acquire1.release(); // NODE_1 + assertTrue(catalog1acquire3.getNode().isDone()); + assertEquals(catalog1acquire3.getNode().get().getNode(), NODE_1); - ListenableFuture catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); - assertFalse(catalog1acquire4.isDone()); + NodeAllocator.NodeLease catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + assertFalse(catalog1acquire4.getNode().isDone()); - ListenableFuture catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); - assertFalse(catalog2acquire4.isDone()); + NodeAllocator.NodeLease catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + assertFalse(catalog2acquire4.getNode().isDone()); - nodeAllocator.release(catalog1acquire2.get()); // NODE_3 - assertFalse(catalog2acquire4.isDone()); - assertTrue(catalog1acquire4.isDone()); - assertEquals(catalog1acquire4.get().getNode(), NODE_3); + catalog1acquire2.release(); // NODE_3 + assertFalse(catalog2acquire4.getNode().isDone()); + 
assertTrue(catalog1acquire4.getNode().isDone()); + assertEquals(catalog1acquire4.getNode().get().getNode(), NODE_3); - nodeAllocator.release(catalog1acquire4.get()); // NODE_3 - assertTrue(catalog2acquire4.isDone()); - assertEquals(catalog2acquire4.get().getNode(), NODE_3); + catalog1acquire4.release(); // NODE_3 + assertTrue(catalog2acquire4.getNode().isDone()); + assertEquals(catalog2acquire4.getNode().get().getNode(), NODE_3); } } @Test - public void testCancellation() + public void testReleaseBeforeAcquired() throws Exception { TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of())); setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire2.isDone()); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire2.getNode().isDone()); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire3.isDone()); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire3.getNode().isDone()); - acquire2.cancel(true); + acquire2.release(); - nodeAllocator.release(acquire1.get()); // NODE_1 - assertTrue(acquire3.isDone()); - assertEquals(acquire3.get().getNode(), NODE_1); + 
acquire1.release(); // NODE_1 + assertTrue(acquire3.getNode().isDone()); + assertEquals(acquire3.getNode().get().getNode(), NODE_1); } } @@ -292,17 +291,17 @@ public void testAddNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire2.isDone()); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire2.getNode().isDone()); nodeSupplier.addNode(NODE_2, ImmutableList.of()); nodeAllocatorService.updateNodes(); - assertEquals(acquire2.get(10, SECONDS).getNode(), NODE_2); + assertEquals(acquire2.getNode().get(10, SECONDS).getNode(), NODE_2); } } @@ -314,24 +313,24 @@ public void testRemoveNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_1); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_1); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - 
assertFalse(acquire2.isDone()); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire2.getNode().isDone()); nodeSupplier.removeNode(NODE_1); nodeSupplier.addNode(NODE_2, ImmutableList.of()); nodeAllocatorService.updateNodes(); - assertEquals(acquire2.get(10, SECONDS).getNode(), NODE_2); + assertEquals(acquire2.getNode().get(10, SECONDS).getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertFalse(acquire3.isDone()); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + assertFalse(acquire3.getNode().isDone()); - nodeAllocator.release(acquire1.get()); // NODE_1 - assertFalse(acquire3.isDone()); + acquire1.release(); // NODE_1 + assertFalse(acquire3.getNode().isDone()); } } @@ -343,44 +342,39 @@ public void testAddressRequirement() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - ListenableFuture acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); - assertTrue(acquire1.isDone()); - assertEquals(acquire1.get().getNode(), NODE_2); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + assertTrue(acquire1.getNode().isDone()); + assertEquals(acquire1.getNode().get().getNode(), NODE_2); - ListenableFuture acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); - assertFalse(acquire2.isDone()); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + assertFalse(acquire2.getNode().isDone()); - nodeAllocator.release(acquire1.get()); // NODE_2 + acquire1.release(); // NODE_2 - 
assertTrue(acquire2.isDone()); - assertEquals(acquire2.get().getNode(), NODE_2); + assertTrue(acquire2.getNode().isDone()); + assertEquals(acquire2.getNode().get().getNode(), NODE_2); - ListenableFuture acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); - assertTrue(acquire3.isDone()); - assertThatThrownBy(acquire3::get) + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + assertTrue(acquire3.getNode().isDone()); + assertThatThrownBy(() -> acquire3.getNode().get()) .hasMessageContaining("No nodes available to run query"); nodeSupplier.addNode(NODE_3, ImmutableList.of()); nodeAllocatorService.updateNodes(); - ListenableFuture acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); - assertTrue(acquire4.isDone()); - assertEquals(acquire4.get().getNode(), NODE_3); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + assertTrue(acquire4.getNode().isDone()); + assertEquals(acquire4.getNode().get().getNode(), NODE_3); - ListenableFuture acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); - assertFalse(acquire5.isDone()); + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + assertFalse(acquire5.getNode().isDone()); nodeSupplier.removeNode(NODE_3); nodeAllocatorService.updateNodes(); - assertTrue(acquire5.isDone()); - assertThatThrownBy(acquire5::get) + assertTrue(acquire5.getNode().isDone()); + assertThatThrownBy(() -> acquire5.getNode().get()) .hasMessageContaining("No nodes available to run query"); } } - - private NodeScheduler createNodeScheduler(TestingNodeSupplier testingNodeSupplier) - { - return new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, 
testingNodeSupplier)); - } } From 444211157ede702a64a247100c54e20afad1a729 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Wed, 2 Mar 2022 14:50:44 -0800 Subject: [PATCH 08/11] Add memory to task requirements passed to node allocator --- .../io/trino/SystemSessionProperties.java | 11 +++ .../execution/scheduler/NodeRequirements.java | 20 +++- .../scheduler/StageTaskSourceFactory.java | 47 ++++++--- .../io/trino/memory/MemoryManagerConfig.java | 17 ++++ .../TestFaultTolerantStageScheduler.java | 9 +- .../TestFixedCountNodeAllocator.java | 81 +++++++-------- .../scheduler/TestStageTaskSourceFactory.java | 99 ++++++++++--------- .../scheduler/TestTaskDescriptorStorage.java | 3 +- .../scheduler/TestingTaskSourceFactory.java | 4 +- .../trino/memory/TestMemoryManagerConfig.java | 7 +- 10 files changed, 188 insertions(+), 110 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 8368cf1e2716..7b3c98dcf3db 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -153,6 +153,7 @@ public final class SystemSessionProperties public static final String FAULT_TOLERANT_EXECUTION_MIN_TASK_SPLIT_COUNT = "fault_tolerant_execution_min_task_split_count"; public static final String FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT = "fault_tolerant_execution_target_task_split_count"; public static final String FAULT_TOLERANT_EXECUTION_MAX_TASK_SPLIT_COUNT = "fault_tolerant_execution_max_task_split_count"; + public static final String FAULT_TOLERANT_EXECUTION_TASK_MEMORY = "fault_tolerant_execution_task_memory"; private final List> sessionProperties; @@ -721,6 +722,11 @@ public SystemSessionProperties( FAULT_TOLERANT_EXECUTION_MAX_TASK_SPLIT_COUNT, "Maximal number of splits for a single fault tolerant task (count based)", 
queryManagerConfig.getFaultTolerantExecutionMaxTaskSplitCount(), + false), + dataSizeProperty( + FAULT_TOLERANT_EXECUTION_TASK_MEMORY, + "Estimated amount of memory a single task will use when task level retries are used; value is used allocating nodes for tasks execution", + memoryManagerConfig.getFaultTolerantTaskMemory(), false)); } @@ -1297,4 +1303,9 @@ public static int getFaultTolerantExecutionMaxTaskSplitCount(Session session) { return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_MAX_TASK_SPLIT_COUNT, Integer.class); } + + public static DataSize getFaultTolerantExecutionDefaultTaskMemory(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TASK_MEMORY, DataSize.class); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java index 5ce6e1e126c4..34ecc2db7e3e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeRequirements.java @@ -14,6 +14,7 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; import io.trino.connector.CatalogName; import io.trino.spi.HostAddress; import org.openjdk.jol.info.ClassLayout; @@ -33,11 +34,13 @@ public class NodeRequirements private final Optional catalogName; private final Set addresses; + private final DataSize memory; - public NodeRequirements(Optional catalogName, Set addresses) + public NodeRequirements(Optional catalogName, Set addresses, DataSize memory) { this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.addresses = ImmutableSet.copyOf(requireNonNull(addresses, "addresses is null")); + this.memory = requireNonNull(memory, "memory is null"); } /* @@ -56,6 +59,16 @@ public Set getAddresses() return addresses; } + public DataSize getMemory() + { + return memory; + } 
+ + public NodeRequirements withMemory(DataSize memory) + { + return new NodeRequirements(catalogName, addresses, memory); + } + @Override public boolean equals(Object o) { @@ -66,13 +79,13 @@ public boolean equals(Object o) return false; } NodeRequirements that = (NodeRequirements) o; - return Objects.equals(catalogName, that.catalogName) && Objects.equals(addresses, that.addresses); + return Objects.equals(catalogName, that.catalogName) && Objects.equals(addresses, that.addresses) && Objects.equals(memory, that.memory); } @Override public int hashCode() { - return Objects.hash(catalogName, addresses); + return Objects.hash(catalogName, addresses, memory); } @Override @@ -81,6 +94,7 @@ public String toString() return toStringHelper(this) .add("catalogName", catalogName) .add("addresses", addresses) + .add("memory", memory) .toString(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java index 387f8dd35ef7..86deabc9445d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java @@ -69,6 +69,7 @@ import static com.google.common.collect.Sets.union; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinTaskSplitCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; @@ -128,10 +129,11 @@ public TaskSource create( PartitioningHandle partitioning = fragment.getPartitioning(); if (partitioning.equals(SINGLE_DISTRIBUTION)) { - 
return SingleDistributionTaskSource.create(fragment, exchangeSourceHandles); + return SingleDistributionTaskSource.create(session, fragment, exchangeSourceHandles); } if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_DISTRIBUTION)) { return ArbitraryDistributionTaskSource.create( + session, fragment, sourceExchanges, exchangeSourceHandles, @@ -175,16 +177,18 @@ public static class SingleDistributionTaskSource private final ListMultimap exchangeSourceHandles; private boolean finished; + private DataSize taskMemory; - public static SingleDistributionTaskSource create(PlanFragment fragment, Multimap exchangeSourceHandles) + public static SingleDistributionTaskSource create(Session session, PlanFragment fragment, Multimap exchangeSourceHandles) { checkArgument(fragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", fragment.getPartitionedSources()); - return new SingleDistributionTaskSource(getInputsForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles)); + return new SingleDistributionTaskSource(getInputsForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles), getFaultTolerantExecutionDefaultTaskMemory(session)); } - public SingleDistributionTaskSource(ListMultimap exchangeSourceHandles) + public SingleDistributionTaskSource(ListMultimap exchangeSourceHandles, DataSize taskMemory) { this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); + this.taskMemory = requireNonNull(taskMemory, "taskMemory is null"); } @Override @@ -194,7 +198,7 @@ public List getMoreTasks() 0, ImmutableListMultimap.of(), exchangeSourceHandles, - new NodeRequirements(Optional.empty(), ImmutableSet.of()))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), taskMemory))); finished = true; return result; } @@ -218,10 +222,12 @@ public static class ArbitraryDistributionTaskSource private final 
Multimap partitionedExchangeSourceHandles; private final Multimap replicatedExchangeSourceHandles; private final long targetPartitionSizeInBytes; + private DataSize taskMemory; private boolean finished; public static ArbitraryDistributionTaskSource create( + Session session, PlanFragment fragment, Map sourceExchanges, Multimap exchangeSourceHandles, @@ -234,18 +240,21 @@ public static ArbitraryDistributionTaskSource create( exchangeForHandleMap, getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles), getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), - targetPartitionSize); + targetPartitionSize, + getFaultTolerantExecutionDefaultTaskMemory(session)); } public ArbitraryDistributionTaskSource( IdentityHashMap sourceExchanges, Multimap partitionedExchangeSourceHandles, Multimap replicatedExchangeSourceHandles, - DataSize targetPartitionSize) + DataSize targetPartitionSize, + DataSize taskMemory) { this.sourceExchanges = new IdentityHashMap<>(requireNonNull(sourceExchanges, "sourceExchanges is null")); this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null")); this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); + this.taskMemory = requireNonNull(taskMemory, "taskMemory is null"); checkArgument( sourceExchanges.keySet().containsAll(partitionedExchangeSourceHandles.values()), "Unexpected entries in partitionedExchangeSourceHandles map: %s; allowed keys: %s", @@ -262,7 +271,7 @@ public ArbitraryDistributionTaskSource( @Override public List getMoreTasks() { - NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), ImmutableSet.of()); + NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), ImmutableSet.of(), taskMemory); ImmutableList.Builder result = ImmutableList.builder(); int 
currentPartitionId = 0; @@ -336,6 +345,7 @@ public static class HashDistributionTaskSource private final LongConsumer getSplitTimeRecorder; private final int[] bucketToPartitionMap; private final Optional bucketNodeMap; + private final DataSize taskMemory; private final Optional catalogRequirement; private final long targetPartitionSourceSizeInBytes; // compared data read from ExchangeSources private final long targetPartitionSplitWeight; // compared against splits from SplitSources @@ -369,7 +379,8 @@ public static HashDistributionTaskSource create( bucketToPartitionMap, bucketNodeMap, fragment.getPartitioning().getConnectorId(), - targetPartitionSplitWeight, targetPartitionSourceSize); + targetPartitionSplitWeight, targetPartitionSourceSize, + getFaultTolerantExecutionDefaultTaskMemory(session)); } public HashDistributionTaskSource( @@ -383,7 +394,8 @@ public HashDistributionTaskSource( Optional bucketNodeMap, Optional catalogRequirement, long targetPartitionSplitWeight, - DataSize targetPartitionSourceSize) + DataSize targetPartitionSourceSize, + DataSize taskMemory) { this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null")); this.exchangeForHandle = new IdentityHashMap<>(); @@ -394,6 +406,7 @@ public HashDistributionTaskSource( this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); this.bucketToPartitionMap = requireNonNull(bucketToPartitionMap, "bucketToPartitionMap is null"); this.bucketNodeMap = requireNonNull(bucketNodeMap, "bucketNodeMap is null"); + this.taskMemory = requireNonNull(taskMemory, "taskMemory is null"); checkArgument(bucketNodeMap.isPresent() || splitSources.isEmpty(), "bucketNodeMap is expected to be set when the fragment reads partitioned sources (tables)"); this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); this.targetPartitionSourceSizeInBytes = requireNonNull(targetPartitionSourceSize, "targetPartitionSourceSize is 
null").toBytes(); @@ -460,7 +473,7 @@ public List getMoreTasks() .build(); HostAddress host = partitionToNodeMap.get(partition); Set hostRequirement = host == null ? ImmutableSet.of() : ImmutableSet.of(host); - partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement))); + partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits, exchangeSourceHandles, new NodeRequirements(catalogRequirement, hostRequirement, taskMemory))); } List result = postprocessTasks(partitionTasks.build()); @@ -574,6 +587,7 @@ public static class SourceDistributionTaskSource private final int minPartitionSplitCount; private final long targetPartitionSplitWeight; private final int maxPartitionSplitCount; + private final DataSize taskMemory; private final Set remotelyAccessibleSplitBuffer = newIdentityHashSet(); private final Map> locallyAccessibleSplitBuffer = new HashMap<>(); @@ -617,7 +631,8 @@ public static SourceDistributionTaskSource create( catalogName, minPartitionSplitCount, targetPartitionSplitWeight, - maxPartitionSplitCount); + maxPartitionSplitCount, + getFaultTolerantExecutionDefaultTaskMemory(session)); } public SourceDistributionTaskSource( @@ -631,7 +646,8 @@ public SourceDistributionTaskSource( Optional catalogRequirement, int minPartitionSplitCount, long targetPartitionSplitWeight, - int maxPartitionSplitCount) + int maxPartitionSplitCount, + DataSize taskMemory) { this.queryId = requireNonNull(queryId, "queryId is null"); this.partitionedSourceNodeId = requireNonNull(partitionedSourceNodeId, "partitionedSourceNodeId is null"); @@ -651,6 +667,7 @@ public SourceDistributionTaskSource( maxPartitionSplitCount, minPartitionSplitCount); this.maxPartitionSplitCount = maxPartitionSplitCount; + this.taskMemory = requireNonNull(taskMemory, "taskMemory is null"); } @Override @@ -668,7 +685,7 @@ public List getMoreTasks() result.addAll(getReadyTasks( remotelyAccessibleSplitBuffer, 
ImmutableList.of(), - new NodeRequirements(catalogRequirement, ImmutableSet.of()), + new NodeRequirements(catalogRequirement, ImmutableSet.of(), taskMemory), includeRemainder)); for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { result.addAll(getReadyTasks( @@ -677,7 +694,7 @@ public List getMoreTasks() .filter(entry -> !entry.getKey().equals(remoteHost)) .map(Map.Entry::getValue) .collect(toImmutableList()), - new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost)), + new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost), taskMemory), includeRemainder)); } diff --git a/core/trino-main/src/main/java/io/trino/memory/MemoryManagerConfig.java b/core/trino-main/src/main/java/io/trino/memory/MemoryManagerConfig.java index 6903037cacdb..419b4eb76700 100644 --- a/core/trino-main/src/main/java/io/trino/memory/MemoryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/memory/MemoryManagerConfig.java @@ -35,10 +35,13 @@ "resources.reserved-system-memory"}) public class MemoryManagerConfig { + public static final String FAULT_TOLERANT_TASK_MEMORY_CONFIG = "fault-tolerant-task-memory"; + // enforced against user memory allocations private DataSize maxQueryMemory = DataSize.of(20, GIGABYTE); // enforced against user + system memory allocations (default is maxQueryMemory * 2) private DataSize maxQueryTotalMemory; + private DataSize faultTolerantTaskMemory = DataSize.of(1, GIGABYTE); private LowMemoryKillerPolicy lowMemoryKillerPolicy = LowMemoryKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES; private Duration killOnOutOfMemoryDelay = new Duration(5, MINUTES); @@ -98,6 +101,20 @@ public MemoryManagerConfig setMaxQueryTotalMemory(DataSize maxQueryTotalMemory) return this; } + @NotNull + public DataSize getFaultTolerantTaskMemory() + { + return faultTolerantTaskMemory; + } + + @Config(FAULT_TOLERANT_TASK_MEMORY_CONFIG) + @ConfigDescription("Estimated amount of memory a single task will use when task level retries are used; 
value is used allocating nodes for tasks execution") + public MemoryManagerConfig setFaultTolerantTaskMemory(DataSize faultTolerantTaskMemory) + { + this.faultTolerantTaskMemory = faultTolerantTaskMemory; + return this; + } + public enum LowMemoryKillerPolicy { NONE, diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index 6a6774a0030e..054d2d9dd5db 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -64,6 +64,7 @@ import static com.google.common.collect.Iterables.cycle; import static com.google.common.collect.Iterables.limit; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.operator.RetryPolicy.TASK; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; @@ -361,8 +362,8 @@ public void testTaskFailure() // waiting on node acquisition assertBlocked(blocked); - NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))); + NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))); remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); @@ -469,8 +470,8 @@ private void testCancellation(boolean abort) // 
waiting on node acquisition assertBlocked(blocked); - NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); - NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())); + NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))); + NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))); if (abort) { scheduler.abort(); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java index 30de5556708a..37d2fadc56eb 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFixedCountNodeAllocator.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; import io.trino.Session; import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; @@ -28,6 +29,7 @@ import java.net.URI; import java.util.Optional; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.trino.testing.TestingSession.testSessionBuilder; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -51,6 +53,7 @@ public class TestFixedCountNodeAllocator private static final CatalogName CATALOG_1 = new CatalogName("catalog1"); private static final CatalogName CATALOG_2 = new CatalogName("catalog2"); + private static final DataSize MEMORY_REQUIREMENTS = DataSize.of(4, GIGABYTE); private 
FixedCountNodeAllocatorService nodeAllocatorService; @@ -77,11 +80,11 @@ public void testSingleNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire2.getNode().isDone()); acquire1.release(); @@ -91,18 +94,18 @@ public void testSingleNode() } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire2.getNode().isDone()); assertEquals(acquire2.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), 
ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire3.getNode().isDone()); - NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire4.getNode().isDone()); acquire2.release(); // NODE_1 @@ -123,21 +126,21 @@ public void testMultipleNodes() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire2.getNode().isDone()); assertEquals(acquire2.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire3.getNode().isDone()); - NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire4.getNode().isDone()); - NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new 
NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire5.getNode().isDone()); acquire2.release(); // NODE_2 @@ -154,26 +157,26 @@ public void testMultipleNodes() } try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 2)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire2.getNode().isDone()); assertEquals(acquire2.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire3.getNode().isDone()); assertEquals(acquire3.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire4.getNode().isDone()); assertEquals(acquire4.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire5 = 
nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire5.getNode().isDone()); - NodeAllocator.NodeLease acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire6 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire6.getNode().isDone()); acquire4.release(); // NODE_2 @@ -184,7 +187,7 @@ public void testMultipleNodes() assertTrue(acquire6.getNode().isDone()); assertEquals(acquire6.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire7 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire7.getNode().isDone()); acquire6.release(); // NODE_1 @@ -195,7 +198,7 @@ public void testMultipleNodes() acquire5.release(); // NODE_2 acquire2.release(); // NODE_2 - NodeAllocator.NodeLease acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire8 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire8.getNode().isDone()); assertEquals(acquire8.getNode().get().getNode(), NODE_2); } @@ -213,22 +216,22 @@ public void testCatalogRequirement() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + NodeAllocator.NodeLease catalog1acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(catalog1acquire1.getNode().isDone()); 
assertEquals(catalog1acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + NodeAllocator.NodeLease catalog1acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(catalog1acquire2.getNode().isDone()); assertEquals(catalog1acquire2.getNode().get().getNode(), NODE_3); - NodeAllocator.NodeLease catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + NodeAllocator.NodeLease catalog1acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(catalog1acquire3.getNode().isDone()); - NodeAllocator.NodeLease catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + NodeAllocator.NodeLease catalog2acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(catalog2acquire1.getNode().isDone()); assertEquals(catalog2acquire1.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + NodeAllocator.NodeLease catalog2acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(catalog2acquire2.getNode().isDone()); catalog2acquire1.release(); // NODE_2 @@ -240,10 +243,10 @@ public void testCatalogRequirement() assertTrue(catalog1acquire3.getNode().isDone()); assertEquals(catalog1acquire3.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease catalog1acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of())); + NodeAllocator.NodeLease catalog1acquire4 = nodeAllocator.acquire(new 
NodeRequirements(Optional.of(CATALOG_1), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(catalog1acquire4.getNode().isDone()); - NodeAllocator.NodeLease catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of())); + NodeAllocator.NodeLease catalog2acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.of(CATALOG_2), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(catalog2acquire4.getNode().isDone()); catalog1acquire2.release(); // NODE_3 @@ -265,14 +268,14 @@ public void testReleaseBeforeAcquired() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire2.getNode().isDone()); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire3.getNode().isDone()); acquire2.release(); @@ -291,11 +294,11 @@ public void testAddNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + 
NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire2.getNode().isDone()); nodeSupplier.addNode(NODE_2, ImmutableList.of()); @@ -313,11 +316,11 @@ public void testRemoveNode() setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_1); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire2.getNode().isDone()); nodeSupplier.removeNode(NODE_1); @@ -326,7 +329,7 @@ public void testRemoveNode() assertEquals(acquire2.getNode().get(10, SECONDS).getNode(), NODE_2); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of())); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(), MEMORY_REQUIREMENTS)); assertFalse(acquire3.getNode().isDone()); acquire1.release(); // NODE_1 @@ -342,11 +345,11 @@ public void testAddressRequirement() 
setupNodeAllocatorService(nodeSupplier); try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS), MEMORY_REQUIREMENTS)); assertTrue(acquire1.getNode().isDone()); assertEquals(acquire1.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS))); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_2_ADDRESS), MEMORY_REQUIREMENTS)); assertFalse(acquire2.getNode().isDone()); acquire1.release(); // NODE_2 @@ -354,7 +357,7 @@ public void testAddressRequirement() assertTrue(acquire2.getNode().isDone()); assertEquals(acquire2.getNode().get().getNode(), NODE_2); - NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS), MEMORY_REQUIREMENTS)); assertTrue(acquire3.getNode().isDone()); assertThatThrownBy(() -> acquire3.getNode().get()) .hasMessageContaining("No nodes available to run query"); @@ -362,11 +365,11 @@ public void testAddressRequirement() nodeSupplier.addNode(NODE_3, ImmutableList.of()); nodeAllocatorService.updateNodes(); - NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS), MEMORY_REQUIREMENTS)); assertTrue(acquire4.getNode().isDone()); assertEquals(acquire4.getNode().get().getNode(), NODE_3); - 
NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS))); + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(new NodeRequirements(Optional.empty(), ImmutableSet.of(NODE_3_ADDRESS), MEMORY_REQUIREMENTS)); assertFalse(acquire5.getNode().isDone()); nodeSupplier.removeNode(NODE_3); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java index 49276ce8360a..48600fe0f265 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java @@ -88,7 +88,7 @@ public void testSingleDistributionTaskSource() .put(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)) .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 222)) .build(); - TaskSource taskSource = new SingleDistributionTaskSource(sources); + TaskSource taskSource = new SingleDistributionTaskSource(sources, DataSize.of(4, GIGABYTE)); assertFalse(taskSource.isFinished()); @@ -113,7 +113,8 @@ public void testArbitraryDistributionTaskSource() TaskSource taskSource = new ArbitraryDistributionTaskSource(new IdentityHashMap<>(), ImmutableListMultimap.of(), ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); assertFalse(taskSource.isFinished()); List tasks = taskSource.getMoreTasks(); assertThat(tasks).isEmpty(); @@ -131,7 +132,8 @@ public void testArbitraryDistributionTaskSource() new IdentityHashMap<>(ImmutableMap.of(sourceHandle3, exchange)), nonReplicatedSources, ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertTrue(taskSource.isFinished()); assertThat(tasks).hasSize(1); @@ -139,7 +141,7 @@ public void 
testArbitraryDistributionTaskSource() 0, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); nonReplicatedSources = ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123); exchange = nonSplittingExchangeManager.createExchange(new ExchangeContext(new QueryId("query"), createRandomExchangeId()), 3); @@ -147,13 +149,14 @@ public void testArbitraryDistributionTaskSource() new IdentityHashMap<>(ImmutableMap.of(sourceHandle123, exchange)), nonReplicatedSources, ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertEquals(tasks, ImmutableList.of(new TaskDescriptor( 0, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle123, @@ -165,19 +168,20 @@ public void testArbitraryDistributionTaskSource() sourceHandle321, exchange)), nonReplicatedSources, ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new 
NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle1, @@ -191,7 +195,8 @@ public void testArbitraryDistributionTaskSource() sourceHandle4, exchange)), nonReplicatedSources, ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( @@ -200,17 +205,17 @@ public void testArbitraryDistributionTaskSource() ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 3)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); nonReplicatedSources = ImmutableListMultimap.of( PLAN_NODE_1, sourceHandle1, @@ -224,29 +229,30 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), sourceHandle4, exchange)), nonReplicatedSources, ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( 0, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new 
NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 3)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 3, ImmutableListMultimap.of(), ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); // with replicated sources nonReplicatedSources = ImmutableListMultimap.of( @@ -264,7 +270,8 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), sourceHandle321, exchange)), nonReplicatedSources, replicatedSources, - DataSize.of(3, BYTE)); + DataSize.of(3, BYTE), + DataSize.of(4, GIGABYTE)); tasks = taskSource.getMoreTasks(); assertEquals(tasks, ImmutableList.of( new TaskDescriptor( @@ -274,21 +281,21 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2)), PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_1, new TestingExchangeSourceHandle(0, 2), PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321)), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 3), PLAN_NODE_2, sourceHandle321), - new NodeRequirements(Optional.empty(), ImmutableSet.of())), + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, 
GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_2, sourceHandle321), - new NodeRequirements(Optional.empty(), ImmutableSet.of())))); + new NodeRequirements(Optional.empty(), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); } @Test @@ -326,13 +333,13 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new TaskDescriptor(0, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor(1, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor(2, ImmutableListMultimap.of(), ImmutableListMultimap.of( PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); Split bucketedSplit1 = createBucketedSplit(0, 0); @@ -359,25 +366,25 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit1), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), 
ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of( PLAN_NODE_5, bucketedSplit4), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit2), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 3, ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit3), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); taskSource = createHashDistributionTaskSource( @@ -405,27 +412,27 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of( PLAN_NODE_5, bucketedSplit4), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - 
PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit2), ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 3, ImmutableListMultimap.of( PLAN_NODE_4, bucketedSplit3), ImmutableListMultimap.of( PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); taskSource = createHashDistributionTaskSource( @@ -452,7 +459,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of( @@ -460,7 +467,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Option PLAN_NODE_5, bucketedSplit4), ImmutableListMultimap.of( PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + 
PLAN_NODE_3, new TestingExchangeSourceHandle(0, 1)), new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); // join based on split target split weight @@ -493,7 +500,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 1), PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(1, 1), PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of( @@ -503,7 +510,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1), PLAN_NODE_2, new TestingExchangeSourceHandle(3, 1), PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); // join based on target exchange size @@ -536,7 +543,7 @@ PLAN_NODE_1, new TestingExchangeSourceHandle(0, 20), PLAN_NODE_1, new TestingExchangeSourceHandle(1, 30), PLAN_NODE_2, new TestingExchangeSourceHandle(1, 20), PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 1, ImmutableListMultimap.of( @@ -544,7 +551,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), ImmutableListMultimap.of( PLAN_NODE_2, new TestingExchangeSourceHandle(2, 99), PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())), + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))), new TaskDescriptor( 2, ImmutableListMultimap.of( @@ -552,7 
+559,7 @@ PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), ImmutableListMultimap.of( PLAN_NODE_2, new TestingExchangeSourceHandle(3, 30), PLAN_NODE_3, new TestingExchangeSourceHandle(17, 1)), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); } @@ -583,7 +590,8 @@ private static HashDistributionTaskSource createHashDistributionTaskSource( bucketNodeMap, Optional.of(CATALOG), targetPartitionSplitWeight, - targetPartitionSourceSize); + targetPartitionSourceSize, + DataSize.of(4, GIGABYTE)); } @Test @@ -609,7 +617,7 @@ public void testSourceDistributionTaskSource() 0, ImmutableListMultimap.of(PLAN_NODE_1, split1), ImmutableListMultimap.of(), - new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of())))); + new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE))))); assertTrue(taskSource.isFinished()); taskSource = createSourceDistributionTaskSource( @@ -624,7 +632,7 @@ public void testSourceDistributionTaskSource() assertThat(tasks).hasSize(2); assertThat(tasks.get(0).getSplits().values()).hasSize(2); assertThat(tasks.get(1).getSplits().values()).hasSize(1); - assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of()))); + assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE)))); assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getExchangeSourceHandles().isEmpty()); assertThat(flattenSplits(tasks)).hasSameEntriesAs(ImmutableMultimap.of( PLAN_NODE_1, split1, @@ -645,7 +653,7 @@ public void testSourceDistributionTaskSource() assertThat(tasks).hasSize(2); assertThat(tasks.get(0).getSplits().values()).hasSize(2); 
assertThat(tasks.get(1).getSplits().values()).hasSize(1); - assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of()))); + assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(CATALOG), ImmutableSet.of(), DataSize.of(4, GIGABYTE)))); assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getExchangeSourceHandles().equals(replicatedSources)); assertThat(flattenSplits(tasks)).hasSameEntriesAs(ImmutableMultimap.of( PLAN_NODE_1, split1, @@ -801,7 +809,8 @@ private static SourceDistributionTaskSource createSourceDistributionTaskSource( Optional.of(CATALOG), minSplitsPerTask, splitWeightPerTask, - maxSplitsPerTask); + maxSplitsPerTask, + DataSize.of(4, GIGABYTE)); } private static Split createSplit(int id, String... addresses) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java index 7c58de587840..ef371a4c273c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestTaskDescriptorStorage.java @@ -26,6 +26,7 @@ import java.util.Optional; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.trino.spi.StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY; import static org.assertj.core.api.Assertions.assertThat; @@ -177,7 +178,7 @@ private static TaskDescriptor createTaskDescriptor(int partitionId, DataSize ret partitionId, ImmutableListMultimap.of(), ImmutableListMultimap.of(new PlanNodeId("1"), new TestingExchangeSourceHandle(retainedSize.toBytes())), - new NodeRequirements(catalog, ImmutableSet.of())); + new NodeRequirements(catalog, ImmutableSet.of(), DataSize.of(4, 
GIGABYTE))); } private static Optional getCatalogName(TaskDescriptor descriptor) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java index 3b8ea14b08f2..8d41c732b0b8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; +import io.airlift.units.DataSize; import io.trino.Session; import io.trino.connector.CatalogName; import io.trino.metadata.Split; @@ -39,6 +40,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static java.util.Objects.requireNonNull; @@ -133,7 +135,7 @@ public List getMoreTasks() nextPartitionId.getAndIncrement(), ImmutableListMultimap.of(tableScanPlanNodeId, split), exchangeSourceHandles, - new NodeRequirements(catalogRequirement, ImmutableSet.of())); + new NodeRequirements(catalogRequirement, ImmutableSet.of(), DataSize.of(4, GIGABYTE))); result.add(task); } diff --git a/core/trino-main/src/test/java/io/trino/memory/TestMemoryManagerConfig.java b/core/trino-main/src/test/java/io/trino/memory/TestMemoryManagerConfig.java index 11ddc9122ae8..6ca727e05930 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestMemoryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestMemoryManagerConfig.java @@ -38,7 +38,8 @@ public void testDefaults() .setLowMemoryKillerPolicy(TOTAL_RESERVATION_ON_BLOCKED_NODES) 
.setKillOnOutOfMemoryDelay(new Duration(5, MINUTES)) .setMaxQueryMemory(DataSize.of(20, GIGABYTE)) - .setMaxQueryTotalMemory(DataSize.of(40, GIGABYTE))); + .setMaxQueryTotalMemory(DataSize.of(40, GIGABYTE)) + .setFaultTolerantTaskMemory(DataSize.of(1, GIGABYTE))); } @Test @@ -49,13 +50,15 @@ public void testExplicitPropertyMappings() .put("query.low-memory-killer.delay", "20s") .put("query.max-memory", "2GB") .put("query.max-total-memory", "3GB") + .put("fault-tolerant-task-memory", "2GB") .buildOrThrow(); MemoryManagerConfig expected = new MemoryManagerConfig() .setLowMemoryKillerPolicy(NONE) .setKillOnOutOfMemoryDelay(new Duration(20, SECONDS)) .setMaxQueryMemory(DataSize.of(2, GIGABYTE)) - .setMaxQueryTotalMemory(DataSize.of(3, GIGABYTE)); + .setMaxQueryTotalMemory(DataSize.of(3, GIGABYTE)) + .setFaultTolerantTaskMemory(DataSize.of(2, GIGABYTE)); assertFullMapping(properties, expected); } From 8cb727f40a0d71f97df4b2466d4d4e385a3b39b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Thu, 13 Jan 2022 14:36:23 +0100 Subject: [PATCH 09/11] Extract memory requirements evolution logic to PartitionMemoryEstimator Different implementations of NodeAllocator require different scheme of selecting memory requirements for a partition on retries. Commit introduces PartitionMemoryEstimator interface and two separate implementations.
* ConstantPartitionMemoryEstimator to be used with FixedCountNodeAllocator * FallbackToFullNodePartitionMemoryEstimator to be used with FullNodeCapableNodeAllocator --- .../io/trino/execution/SqlQueryExecution.java | 9 ++++ .../ConstantPartitionMemoryEstimator.java | 36 ++++++++++++++ .../FaultTolerantStageScheduler.java | 19 ++++++- .../scheduler/PartitionMemoryEstimator.java | 49 +++++++++++++++++++ .../scheduler/SqlQueryScheduler.java | 10 +++- .../io/trino/server/CoordinatorModule.java | 3 ++ .../TestFaultTolerantStageScheduler.java | 1 + 7 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/ConstantPartitionMemoryEstimator.java create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionMemoryEstimator.java diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 4d16ef25e6fe..459f008d08ed 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -27,6 +27,7 @@ import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeScheduler; +import io.trino.execution.scheduler.PartitionMemoryEstimator; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.SqlQueryScheduler; import io.trino.execution.scheduler.TaskDescriptorStorage; @@ -101,6 +102,7 @@ public class SqlQueryExecution private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; private final NodeAllocatorService nodeAllocatorService; + private final PartitionMemoryEstimator partitionMemoryEstimator; private final List planOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory 
remoteTaskFactory; @@ -135,6 +137,7 @@ private SqlQueryExecution( NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, NodeAllocatorService nodeAllocatorService, + PartitionMemoryEstimator partitionMemoryEstimator, List planOptimizers, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory, @@ -163,6 +166,7 @@ private SqlQueryExecution( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); + this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); this.planOptimizers = requireNonNull(planOptimizers, "planOptimizers is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -502,6 +506,7 @@ private void planDistribution(PlanRoot plan) nodePartitioningManager, nodeScheduler, nodeAllocatorService, + partitionMemoryEstimator, remoteTaskFactory, plan.isSummarizeTaskInfos(), scheduleSplitBatchSize, @@ -704,6 +709,7 @@ public static class SqlQueryExecutionFactory private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; private final NodeAllocatorService nodeAllocatorService; + private final PartitionMemoryEstimator partitionMemoryEstimator; private final List planOptimizers; private final PlanFragmenter planFragmenter; private final RemoteTaskFactory remoteTaskFactory; @@ -731,6 +737,7 @@ public static class SqlQueryExecutionFactory NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, NodeAllocatorService nodeAllocatorService, + PartitionMemoryEstimator partitionMemoryEstimator, PlanOptimizersFactory planOptimizersFactory, PlanFragmenter planFragmenter, RemoteTaskFactory 
remoteTaskFactory, @@ -759,6 +766,7 @@ public static class SqlQueryExecutionFactory this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); + this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null"); this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); @@ -799,6 +807,7 @@ public QueryExecution createQueryExecution( nodePartitioningManager, nodeScheduler, nodeAllocatorService, + partitionMemoryEstimator, planOptimizers, planFragmenter, remoteTaskFactory, diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ConstantPartitionMemoryEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ConstantPartitionMemoryEstimator.java new file mode 100644 index 000000000000..b457ac2516fd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ConstantPartitionMemoryEstimator.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.spi.ErrorCode; + +public class ConstantPartitionMemoryEstimator + implements PartitionMemoryEstimator +{ + @Override + public MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit) + { + return new MemoryRequirements( + defaultMemoryLimit, + true); + } + + @Override + public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, ErrorCode errorCode) + { + return previousMemoryRequirements; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index e11c8a1567a1..305b472dba52 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -33,6 +33,7 @@ import io.trino.execution.TaskState; import io.trino.execution.TaskStatus; import io.trino.execution.buffer.OutputBuffers; +import io.trino.execution.scheduler.PartitionMemoryEstimator.MemoryRequirements; import io.trino.failuredetector.FailureDetector; import io.trino.metadata.Split; import io.trino.spi.ErrorCode; @@ -93,6 +94,7 @@ public class FaultTolerantStageScheduler private final TaskSourceFactory taskSourceFactory; private final NodeAllocator nodeAllocator; private final TaskDescriptorStorage taskDescriptorStorage; + private final PartitionMemoryEstimator partitionMemoryEstimator; private final TaskLifecycleListener taskLifecycleListener; // empty when the results are consumed via a direct exchange @@ -129,6 +131,8 @@ public class FaultTolerantStageScheduler private final Set finishedPartitions = new HashSet<>(); @GuardedBy("this") private int remainingRetryAttempts; + @GuardedBy("this") + private Map 
partitionMemoryRequirements = new HashMap<>(); @GuardedBy("this") private Throwable failure; @@ -142,6 +146,7 @@ public FaultTolerantStageScheduler( TaskSourceFactory taskSourceFactory, NodeAllocator nodeAllocator, TaskDescriptorStorage taskDescriptorStorage, + PartitionMemoryEstimator partitionMemoryEstimator, TaskLifecycleListener taskLifecycleListener, Optional sinkExchange, Optional sinkBucketToPartitionMap, @@ -158,6 +163,7 @@ public FaultTolerantStageScheduler( this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); + this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null"); this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null"); this.sinkBucketToPartitionMap = requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null"); @@ -253,8 +259,11 @@ public synchronized void schedule() } TaskDescriptor taskDescriptor = taskDescriptorOptional.get(); + MemoryRequirements memoryRequirements = partitionMemoryRequirements.computeIfAbsent(partition, ignored -> partitionMemoryEstimator.getInitialMemoryRequirements(session, taskDescriptor.getNodeRequirements().getMemory())); if (nodeLease == null) { - nodeLease = nodeAllocator.acquire(taskDescriptor.getNodeRequirements()); + NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); + nodeRequirements = nodeRequirements.withMemory(memoryRequirements.getRequiredMemory()); + nodeLease = nodeAllocator.acquire(nodeRequirements); } if (!nodeLease.getNode().isDone()) { blocked = asVoid(nodeLease.getNode()); @@ -524,6 +533,14 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional 0 && (errorCode == null || 
errorCode.getType() != USER_ERROR)) { remainingRetryAttempts--; + + // update memory limits for next attempt + MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId); + verify(memoryLimits != null); + MemoryRequirements newMemoryLimits = partitionMemoryEstimator.getNextRetryMemoryRequirements(session, memoryLimits, errorCode); + partitionMemoryRequirements.put(partitionId, newMemoryLimits); + + // reschedule queuedPartitions.add(partitionId); log.debug("Retrying partition %s for stage %s", partitionId, stage.getStageId()); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionMemoryEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionMemoryEstimator.java new file mode 100644 index 000000000000..bef452cfbcaa --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionMemoryEstimator.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.spi.ErrorCode; + +import static java.util.Objects.requireNonNull; + +public interface PartitionMemoryEstimator +{ + MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit); + + MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, ErrorCode errorCode); + + class MemoryRequirements + { + private final DataSize requiredMemory; + private final boolean limitReached; + + MemoryRequirements(DataSize requiredMemory, boolean limitReached) + { + this.requiredMemory = requireNonNull(requiredMemory, "requiredMemory is null"); + this.limitReached = limitReached; + } + + public DataSize getRequiredMemory() + { + return requiredMemory; + } + + public boolean isLimitReached() + { + return limitReached; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index 4baf816283fe..92005b6b1119 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -175,6 +175,7 @@ public class SqlQueryScheduler private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; private final NodeAllocatorService nodeAllocatorService; + private final PartitionMemoryEstimator partitionMemoryEstimator; private final int splitBatchSize; private final ExecutorService executor; private final ScheduledExecutorService schedulerExecutor; @@ -211,6 +212,7 @@ public SqlQueryScheduler( NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, NodeAllocatorService nodeAllocatorService, + PartitionMemoryEstimator partitionMemoryEstimator, RemoteTaskFactory remoteTaskFactory, 
boolean summarizeTaskInfo, int splitBatchSize, @@ -233,6 +235,7 @@ public SqlQueryScheduler( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); + this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); this.splitBatchSize = splitBatchSize; this.executor = requireNonNull(queryExecutor, "queryExecutor is null"); this.schedulerExecutor = requireNonNull(schedulerExecutor, "schedulerExecutor is null"); @@ -344,7 +347,8 @@ private synchronized Optional createDistributedStage maxRetryAttempts, schedulerExecutor, schedulerStats, - nodeAllocatorService); + nodeAllocatorService, + partitionMemoryEstimator); break; case QUERY: case NONE: @@ -1744,7 +1748,8 @@ public static FaultTolerantDistributedStagesScheduler create( int retryAttempts, ScheduledExecutorService scheduledExecutorService, SplitSchedulerStats schedulerStats, - NodeAllocatorService nodeAllocatorService) + NodeAllocatorService nodeAllocatorService, + PartitionMemoryEstimator partitionMemoryEstimator) { taskDescriptorStorage.initialize(queryStateMachine.getQueryId()); queryStateMachine.addStateChangeListener(state -> { @@ -1806,6 +1811,7 @@ public static FaultTolerantDistributedStagesScheduler create( taskSourceFactory, nodeAllocator, taskDescriptorStorage, + partitionMemoryEstimator, taskLifecycleListener, exchange, bucketToPartitionCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(), diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index b002c719c3ab..fe182266fcc2 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ 
b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -60,8 +60,10 @@ import io.trino.execution.resourcegroups.InternalResourceGroupManager; import io.trino.execution.resourcegroups.LegacyResourceGroupConfigurationManager; import io.trino.execution.resourcegroups.ResourceGroupManager; +import io.trino.execution.scheduler.ConstantPartitionMemoryEstimator; import io.trino.execution.scheduler.FixedCountNodeAllocatorService; import io.trino.execution.scheduler.NodeAllocatorService; +import io.trino.execution.scheduler.PartitionMemoryEstimator; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.StageTaskSourceFactory; import io.trino.execution.scheduler.TaskDescriptorStorage; @@ -213,6 +215,7 @@ protected void setup(Binder binder) // node allocator binder.bind(NodeAllocatorService.class).to(FixedCountNodeAllocatorService.class).in(Scopes.SINGLETON); + binder.bind(PartitionMemoryEstimator.class).to(ConstantPartitionMemoryEstimator.class).in(Scopes.SINGLETON); // node monitor binder.bind(ClusterSizeMonitor.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index 054d2d9dd5db..ab74f4ef8b70 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -509,6 +509,7 @@ private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( taskSourceFactory, nodeAllocator, taskDescriptorStorage, + new ConstantPartitionMemoryEstimator(), taskLifecycleListener, sinkExchange, Optional.empty(), From 2ccf7f6b1f9a277dbc53fe7188bebf8cbd619215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Tue, 25 Jan 2022 14:09:33 +0100 Subject: [PATCH 10/11] Add node allocator 
supporting full-node acquisitions --- ...ackToFullNodePartitionMemoryEstimator.java | 53 ++ .../FullNodeCapableNodeAllocatorService.java | 664 +++++++++++++++++ .../scheduler/NodeSchedulerConfig.java | 61 ++ .../io/trino/memory/ClusterMemoryManager.java | 3 +- .../io/trino/server/CoordinatorModule.java | 21 +- .../execution/TestNodeSchedulerConfig.java | 13 +- .../TestFullNodeCapableNodeAllocator.java | 698 ++++++++++++++++++ 7 files changed, 1507 insertions(+), 6 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FallbackToFullNodePartitionMemoryEstimator.java create mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FullNodeCapableNodeAllocatorService.java create mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestFullNodeCapableNodeAllocator.java diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FallbackToFullNodePartitionMemoryEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FallbackToFullNodePartitionMemoryEstimator.java new file mode 100644 index 000000000000..d0e1fa559aa9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FallbackToFullNodePartitionMemoryEstimator.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.spi.ErrorCode; + +import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY; +import static io.trino.spi.StandardErrorCode.EXCEEDED_LOCAL_MEMORY_LIMIT; + +public class FallbackToFullNodePartitionMemoryEstimator + implements PartitionMemoryEstimator +{ + // temporarily express full-node requirement as huge amount of memory + public static final DataSize FULL_NODE_MEMORY = DataSize.of(512, DataSize.Unit.GIGABYTE); + + private static final MemoryRequirements FULL_NODE_MEMORY_REQUIREMENTS = new MemoryRequirements(FULL_NODE_MEMORY, true); + + @Override + public MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit) + { + return new MemoryRequirements( + defaultMemoryLimit, + false); + } + + @Override + public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, ErrorCode errorCode) + { + if (shouldRescheduleWithFullNode(errorCode)) { + return FULL_NODE_MEMORY_REQUIREMENTS; + } + return previousMemoryRequirements; + } + + private boolean shouldRescheduleWithFullNode(ErrorCode errorCode) + { + return EXCEEDED_LOCAL_MEMORY_LIMIT.toErrorCode().equals(errorCode) // too many tasks from single query on a node + || CLUSTER_OUT_OF_MEMORY.toErrorCode().equals(errorCode); // too many tasks in general on a node + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FullNodeCapableNodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FullNodeCapableNodeAllocatorService.java new file mode 100644 index 000000000000..bf9d14c0b478 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FullNodeCapableNodeAllocatorService.java @@ -0,0 +1,664 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.airlift.log.Logger; +import io.trino.Session; +import io.trino.connector.CatalogName; +import io.trino.memory.ClusterMemoryManager; +import io.trino.memory.MemoryInfo; +import io.trino.metadata.InternalNode; +import io.trino.spi.QueryId; +import io.trino.spi.TrinoException; +import org.assertj.core.util.VisibleForTesting; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import javax.annotation.concurrent.ThreadSafe; +import javax.inject.Inject; + +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Comparator; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static 
com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static com.google.common.util.concurrent.Futures.transform; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.trino.execution.scheduler.FallbackToFullNodePartitionMemoryEstimator.FULL_NODE_MEMORY; +import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; +import static java.lang.Thread.currentThread; +import static java.util.Comparator.comparing; +import static java.util.Objects.requireNonNull; +import static java.util.function.Predicate.not; + +/** + * Node allocation service which allocates nodes for tasks in two modes: + * - binpacking tasks into nodes + * - reserving full node for a task + * + * The mode selection is handled by {@link FallbackToFullNodePartitionMemoryEstimator}. Initially the task is assigned default memory requirements, + * and binpacking mode of node allocation is used. + * If task execution fails due to an out-of-memory error, a full node is requested for the task on the next attempt. + * + * It is possible to configure a limit of how many full nodes can be assigned for a given query at the same time. + * The limit may be expressed as an absolute value or as a fraction of all the nodes in the cluster.
+ */ +@ThreadSafe +public class FullNodeCapableNodeAllocatorService + implements NodeAllocatorService +{ + private static final Logger log = Logger.get(FullNodeCapableNodeAllocatorService.class); + + @VisibleForTesting + static final int PROCESS_PENDING_ACQUIRES_DELAY_SECONDS = 5; + + private final NodeScheduler nodeScheduler; + private final Supplier>> workerMemoryInfoSupplier; + private final int maxAbsoluteFullNodesPerQuery; + private final double maxFractionFullNodesPerQuery; + + private final List sharedPendingAcquires = new LinkedList<>(); + private final Map fullNodePendingAcquires = new HashMap<>(); + private final Deque detachedFullNodePendingAcquires = new ArrayDeque<>(); + + private final ConcurrentMap sharedAllocatedMemory = new ConcurrentHashMap<>(); + private final Set allocatedFullNodes = new HashSet<>(); + + private final Multimap fullNodesByQueryId = HashMultimap.create(); // both assigned pending and allocated + + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(2, daemonThreadsNamed("bin-packing-node-allocator")); + private final AtomicBoolean started = new AtomicBoolean(); + private final AtomicBoolean stopped = new AtomicBoolean(); + private final Semaphore processSemaphore = new Semaphore(0); + private final ConcurrentMap nodePoolSizes = new ConcurrentHashMap<>(); + + @Inject + public FullNodeCapableNodeAllocatorService( + NodeScheduler nodeScheduler, + ClusterMemoryManager clusterMemoryManager, + NodeSchedulerConfig config) + { + this(nodeScheduler, requireNonNull(clusterMemoryManager, "clusterMemoryManager is null")::getWorkerMemoryInfo, config.getMaxAbsoluteFullNodesPerQuery(), config.getMaxFractionFullNodesPerQuery()); + } + + @VisibleForTesting + FullNodeCapableNodeAllocatorService( + NodeScheduler nodeScheduler, + Supplier>> workerMemoryInfoSupplier, + int maxAbsoluteFullNodesPerQuery, + double maxFractionFullNodesPerQuery) + { + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is 
null"); + this.workerMemoryInfoSupplier = requireNonNull(workerMemoryInfoSupplier, "workerMemoryInfoSupplier is null"); + this.maxAbsoluteFullNodesPerQuery = maxAbsoluteFullNodesPerQuery; + this.maxFractionFullNodesPerQuery = maxFractionFullNodesPerQuery; + } + + private void refreshNodePoolSizes() + { + Map> workerMemoryInfo = workerMemoryInfoSupplier.get(); + for (String key : nodePoolSizes.keySet()) { + if (!workerMemoryInfo.containsKey(key)) { + nodePoolSizes.remove(key); + } + } + + for (Map.Entry> entry : workerMemoryInfo.entrySet()) { + Optional memoryInfo = entry.getValue(); + if (memoryInfo.isEmpty()) { + continue; + } + nodePoolSizes.put(entry.getKey(), memoryInfo.get().getPool().getMaxBytes()); + } + } + + private Optional getNodePoolSize(InternalNode internalNode) + { + return Optional.ofNullable(nodePoolSizes.get(internalNode.getNodeIdentifier())); + } + + @PostConstruct + public void start() + { + if (started.compareAndSet(false, true)) { + executor.schedule(() -> { + while (!stopped.get()) { + try { + // pending acquires are processed when node is released (semaphore is bumped) and periodically (every couple seconds) + // in case node list in cluster have changed. 
+ processSemaphore.tryAcquire(PROCESS_PENDING_ACQUIRES_DELAY_SECONDS, TimeUnit.SECONDS); + processSemaphore.drainPermits(); + processPendingAcquires(); + } + catch (InterruptedException e) { + currentThread().interrupt(); + } + catch (Exception e) { + // ignore to avoid getting unscheduled + log.warn(e, "Error updating nodes"); + } + } + }, 0, TimeUnit.SECONDS); + } + + refreshNodePoolSizes(); + executor.scheduleWithFixedDelay(this::refreshNodePoolSizes, 1, 1, TimeUnit.SECONDS); + } + + @VisibleForTesting + void wakeupProcessPendingAcquires() + { + processSemaphore.release(); + } + + @VisibleForTesting + void processPendingAcquires() + { + processFullNodePendingAcquires(); + processSharedPendingAcquires(); + } + + private void processSharedPendingAcquires() + { + Map assignedNodes = new IdentityHashMap<>(); + Map failures = new IdentityHashMap<>(); + synchronized (this) { + Iterator iterator = sharedPendingAcquires.iterator(); + while (iterator.hasNext()) { + PendingAcquire pendingAcquire = iterator.next(); + if (pendingAcquire.getFuture().isCancelled()) { + iterator.remove(); + continue; + } + try { + Candidates candidates = selectCandidates(pendingAcquire.getNodeRequirements(), pendingAcquire.getNodeSelector()); + if (candidates.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + Optional node = tryAcquireSharedNode(candidates, pendingAcquire.getMemoryLease()); + if (node.isPresent()) { + iterator.remove(); + assignedNodes.put(pendingAcquire, node.get()); + } + } + catch (RuntimeException e) { + iterator.remove(); + failures.put(pendingAcquire, e); + } + } + } + + // complete futures outside of synchronized section + checkState(!Thread.holdsLock(this), "Cannot complete node futures under lock"); + assignedNodes.forEach((pendingAcquire, node) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.set(node); + if (future.isCancelled()) { + releaseSharedNode(node, pendingAcquire.getMemoryLease()); + } 
+ }); + + failures.forEach((pendingAcquire, failure) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.setException(failure); + }); + } + + private void processFullNodePendingAcquires() + { + Map assignedNodes = new IdentityHashMap<>(); + Map failures = new IdentityHashMap<>(); + + synchronized (this) { + Iterator detachedIterator = detachedFullNodePendingAcquires.iterator(); + while (detachedIterator.hasNext()) { + PendingAcquire pendingAcquire = detachedIterator.next(); + try { + if (pendingAcquire.getFuture().isCancelled()) { + // discard cancelled detached pendingAcquire + detachedIterator.remove(); + continue; + } + + Candidates currentCandidates = selectCandidates(pendingAcquire.getNodeRequirements(), pendingAcquire.getNodeSelector()); + if (currentCandidates.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + Optional target = findTargetPendingFullNode(pendingAcquire.getQueryId(), currentCandidates); + if (target.isEmpty()) { + // leave pendingAcquire as pending + continue; + } + + // move pendingAcquire to fullNodePendingAcquires + fullNodePendingAcquires.put(target.get(), pendingAcquire); + fullNodesByQueryId.put(pendingAcquire.getQueryId(), target.get()); + + detachedIterator.remove(); + } + catch (RuntimeException e) { + failures.put(pendingAcquire, e); + detachedIterator.remove(); + } + } + + Set nodes = ImmutableSet.copyOf(fullNodePendingAcquires.keySet()); + for (InternalNode reservedNode : nodes) { + PendingAcquire pendingAcquire = fullNodePendingAcquires.get(reservedNode); + if (pendingAcquire.getFuture().isCancelled()) { + // discard cancelled pendingAcquire with target node + fullNodePendingAcquires.remove(reservedNode); + verify(fullNodesByQueryId.remove(pendingAcquire.getQueryId(), reservedNode)); + continue; + } + try { + Candidates currentCandidates = selectCandidates(pendingAcquire.getNodeRequirements(), pendingAcquire.getNodeSelector()); + if (currentCandidates.isEmpty()) { 
+ throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + if (sharedAllocatedMemory.getOrDefault(reservedNode, 0L) > 0 || allocatedFullNodes.contains(reservedNode)) { + // reserved node is still used - opportunistic check if maybe there is some other empty, not waited for node available + Optional opportunisticNode = currentCandidates.getCandidates().stream() + .filter(node -> !fullNodePendingAcquires.containsKey(node)) + .filter(node -> !allocatedFullNodes.contains(node)) + .filter(node -> sharedAllocatedMemory.getOrDefault(node, 0L) == 0) + .findFirst(); + + if (opportunisticNode.isPresent()) { + fullNodePendingAcquires.remove(reservedNode); + verify(fullNodesByQueryId.remove(pendingAcquire.getQueryId(), reservedNode)); + allocatedFullNodes.add(opportunisticNode.get()); + verify(fullNodesByQueryId.put(pendingAcquire.getQueryId(), opportunisticNode.get())); + assignedNodes.put(pendingAcquire, opportunisticNode.get()); + } + continue; + } + + if (!currentCandidates.getCandidates().contains(reservedNode)) { + // current candidate is gone; move pendingAcquire to detached state + detachedFullNodePendingAcquires.add(pendingAcquire); + fullNodePendingAcquires.remove(reservedNode); + verify(fullNodesByQueryId.remove(pendingAcquire.getQueryId(), reservedNode)); + // trigger one more round of processing immediately + wakeupProcessPendingAcquires(); + continue; + } + + // we are good acquiring reserved full node + allocatedFullNodes.add(reservedNode); + fullNodePendingAcquires.remove(reservedNode); + assignedNodes.put(pendingAcquire, reservedNode); + } + catch (RuntimeException e) { + failures.put(pendingAcquire, e); + fullNodePendingAcquires.remove(reservedNode); + fullNodesByQueryId.remove(pendingAcquire.getQueryId(), reservedNode); + } + } + } + + // complete futures outside of synchronized section + checkState(!Thread.holdsLock(this), "Cannot complete node futures under lock"); + assignedNodes.forEach((pendingAcquire, node) -> { + 
SettableFuture future = pendingAcquire.getFuture(); + future.set(node); + if (future.isCancelled()) { + releaseFullNode(node, pendingAcquire.getQueryId()); + } + }); + + failures.forEach((pendingAcquire, failure) -> { + SettableFuture future = pendingAcquire.getFuture(); + future.setException(failure); + }); + } + + @PreDestroy + public void stop() + { + stopped.set(true); + executor.shutdownNow(); + } + + public synchronized Optional tryAcquire(NodeRequirements requirements, Candidates candidates, QueryId queryId) + { + if (isFullNode(requirements)) { // todo + return tryAcquireFullNode(candidates, queryId); + } + return tryAcquireSharedNode(candidates, requirements.getMemory().toBytes()); + } + + @VisibleForTesting + synchronized Set getPendingFullNodes() + { + return ImmutableSet.copyOf(fullNodePendingAcquires.keySet()); + } + + private synchronized Optional tryAcquireFullNode(Candidates candidates, QueryId queryId) + { + Collection queryFullNodes = fullNodesByQueryId.get(queryId); + + if (fullNodesCountExceeded(queryFullNodes.size(), candidates.getAllNodesCount())) { + return Optional.empty(); + } + + // select nodes which are not used nor waited for + Optional selectedNode = candidates.getCandidates().stream() + .filter(node -> getNodePoolSize(node).isPresent()) // filter out nodes without memory pool information + .filter(node -> sharedAllocatedMemory.getOrDefault(node, 0L) == 0) + .filter(node -> !allocatedFullNodes.contains(node)) + .filter(node -> !fullNodePendingAcquires.containsKey(node)) + .findFirst(); + + selectedNode.ifPresent(node -> { + allocatedFullNodes.add(node); + fullNodesByQueryId.put(queryId, node); + }); + + return selectedNode; + } + + private boolean fullNodesCountExceeded(int currentCount, int candidatesCount) + { + long threshold = Integer.min(maxAbsoluteFullNodesPerQuery, (int) (candidatesCount * maxFractionFullNodesPerQuery)); + return currentCount >= threshold; + } + + private synchronized Optional tryAcquireSharedNode(Candidates 
candidates, long memoryLease) + { + Optional selectedNode = candidates.getCandidates().stream() + .filter(node -> getNodePoolSize(node).map(poolSize -> poolSize - sharedAllocatedMemory.getOrDefault(node, 0L) >= memoryLease).orElse(false)) // not enough memory on the node + .filter(node -> !allocatedFullNodes.contains(node)) // node is used exclusively + .filter(node -> !fullNodePendingAcquires.containsKey(node)) // flushing node to get exclusive use + .min(comparing(node -> sharedAllocatedMemory.getOrDefault(node, 0L))); + + selectedNode.ifPresent(node -> sharedAllocatedMemory.merge(node, memoryLease, Long::sum)); + + return selectedNode; + } + + private synchronized PendingAcquire registerPendingAcquire(NodeRequirements requirements, NodeSelector nodeSelector, Candidates candidates, QueryId queryId) + { + PendingAcquire pendingAcquire = new PendingAcquire(requirements, nodeSelector, queryId); + if (isFullNode(requirements)) { + Optional targetNode = findTargetPendingFullNode(queryId, candidates); + + if (targetNode.isEmpty()) { + detachedFullNodePendingAcquires.add(pendingAcquire); + } + else { + verify(!fullNodePendingAcquires.containsKey(targetNode.get())); + verify(!fullNodesByQueryId.get(queryId).contains(targetNode.get())); + fullNodePendingAcquires.put(targetNode.get(), pendingAcquire); + fullNodesByQueryId.put(queryId, targetNode.get()); + } + } + else { + sharedPendingAcquires.add(pendingAcquire); + } + return pendingAcquire; + } + + private Optional findTargetPendingFullNode(QueryId queryId, Candidates candidates) + { + // nodes which are used by, or reserved for, full-node use for a given query + Collection queryFullNodes = fullNodesByQueryId.get(queryId); + if (fullNodesCountExceeded(queryFullNodes.size(), candidates.getAllNodesCount())) { + return Optional.empty(); + } + return candidates.getCandidates().stream() + .filter(not(queryFullNodes::contains)) + .filter(not(fullNodePendingAcquires::containsKey)) + .min(Comparator.comparing(node -> 
sharedAllocatedMemory.getOrDefault(node, 0L))); + } + + private synchronized void releaseFullNode(InternalNode node, QueryId queryId) + { + verify(allocatedFullNodes.remove(node), "no %s node in allocatedFullNodes", node); + verify(fullNodesByQueryId.remove(queryId, node), "no %s/%s pair in fullNodesByQueryId", queryId, node); + wakeupProcessPendingAcquires(); + } + + private synchronized void releaseSharedNode(InternalNode node, long memoryLease) + { + sharedAllocatedMemory.compute(node, (key, value) -> { + verify(value != null && value >= memoryLease, "invalid memory allocation record %s for node %s", value, key); + long newValue = value - memoryLease; + if (newValue > 0) { + return newValue; + } + return null; // delete entry + }); + wakeupProcessPendingAcquires(); + } + + @Override + public NodeAllocator getNodeAllocator(Session session) + { + return new FullNodeCapableNodeAllocator(session); + } + + private static Candidates selectCandidates(NodeRequirements requirements, NodeSelector nodeSelector) + { + List allNodes = nodeSelector.allNodes(); + return new Candidates( + allNodes.size(), + allNodes.stream() + .filter(node -> requirements.getAddresses().isEmpty() || requirements.getAddresses().contains(node.getHostAndPort())) + .collect(toImmutableList())); + } + + private static class Candidates + { + private final int allNodesCount; + private final List candidates; + + public Candidates(int allNodesCount, List candidates) + { + this.allNodesCount = allNodesCount; + this.candidates = candidates; + } + + public int getAllNodesCount() + { + return allNodesCount; + } + + public List getCandidates() + { + return candidates; + } + + public boolean isEmpty() + { + return candidates.isEmpty(); + } + } + + private class FullNodeCapableNodeAllocator + implements NodeAllocator + { + @GuardedBy("this") + private final Map, NodeSelector> nodeSelectorCache = new HashMap<>(); + private final Session session; + + public FullNodeCapableNodeAllocator(Session session) + { + 
this.session = requireNonNull(session, "session is null"); + } + + @Override + public NodeLease acquire(NodeRequirements requirements) + { + NodeSelector nodeSelector = nodeSelectorCache.computeIfAbsent(requirements.getCatalogName(), catalogName -> nodeScheduler.createNodeSelector(session, catalogName)); + + Candidates candidates = selectCandidates(requirements, nodeSelector); + if (candidates.isEmpty()) { + throw new TrinoException(NO_NODES_AVAILABLE, "No nodes available to run query"); + } + + QueryId queryId = session.getQueryId(); + + Optional selectedNode = tryAcquire(requirements, candidates, queryId); + + if (selectedNode.isPresent()) { + return new FullNodeCapableNodeLease( + immediateFuture(nodeInfoForNode(selectedNode.get())), + requirements.getMemory().toBytes(), + isFullNode(requirements), + queryId); + } + + PendingAcquire pendingAcquire = registerPendingAcquire(requirements, nodeSelector, candidates, queryId); + return new FullNodeCapableNodeLease( + transform(pendingAcquire.getFuture(), this::nodeInfoForNode, directExecutor()), + requirements.getMemory().toBytes(), + isFullNode(requirements), + queryId); + } + + private NodeInfo nodeInfoForNode(InternalNode node) + { + // todo set memory limit properly + return NodeInfo.unlimitedMemoryNode(node); + } + + @Override + public void close() + { + // nothing to do here. leases should be released by the calling party. + // TODO would be great to be able to validate if it actually happened but close() is called from SqlQueryScheduler code + // and that can be done before all leases are yet returned from running (soon to be failed) tasks. 
+ } + } + + private static class PendingAcquire + { + private final NodeRequirements nodeRequirements; + private final NodeSelector nodeSelector; + private final SettableFuture future; + private final QueryId queryId; + + private PendingAcquire(NodeRequirements nodeRequirements, NodeSelector nodeSelector, QueryId queryId) + { + this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); + this.nodeSelector = requireNonNull(nodeSelector, "nodeSelector is null"); + this.queryId = requireNonNull(queryId, "queryId is null"); + this.future = SettableFuture.create(); + } + + public NodeRequirements getNodeRequirements() + { + return nodeRequirements; + } + + public NodeSelector getNodeSelector() + { + return nodeSelector; + } + + public QueryId getQueryId() + { + return queryId; + } + + public SettableFuture getFuture() + { + return future; + } + + public long getMemoryLease() + { + return nodeRequirements.getMemory().toBytes(); + } + } + + private class FullNodeCapableNodeLease + implements NodeAllocator.NodeLease + { + private final ListenableFuture node; + private final AtomicBoolean released = new AtomicBoolean(); + private final long memoryLease; + private final boolean fullNode; + private final QueryId queryId; + + private FullNodeCapableNodeLease(ListenableFuture node, long memoryLease, boolean fullNode, QueryId queryId) + { + this.node = requireNonNull(node, "node is null"); + this.memoryLease = memoryLease; + this.fullNode = fullNode; + this.queryId = requireNonNull(queryId, "queryId is null"); + } + + @Override + public ListenableFuture getNode() + { + return node; + } + + @Override + public void release() + { + if (released.compareAndSet(false, true)) { + node.cancel(true); + if (node.isDone() && !node.isCancelled()) { + if (fullNode) { + releaseFullNode(getFutureValue(node).getNode(), queryId); + } + else { + releaseSharedNode(getFutureValue(node).getNode(), memoryLease); + } + } + } + else { + throw new IllegalStateException("Node " 
+ node + " already released"); + } + } + } + + private boolean isFullNode(NodeRequirements requirements) + { + return requirements.getMemory().compareTo(FULL_NODE_MEMORY) >= 0; + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeSchedulerConfig.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeSchedulerConfig.java index f6479f3bbc37..0a704e6e26af 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeSchedulerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NodeSchedulerConfig.java @@ -18,6 +18,8 @@ import io.airlift.configuration.DefunctConfig; import io.airlift.configuration.LegacyConfig; +import javax.validation.constraints.DecimalMax; +import javax.validation.constraints.DecimalMin; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; @@ -44,6 +46,9 @@ public enum SplitsBalancingPolicy private boolean optimizedLocalScheduling = true; private SplitsBalancingPolicy splitsBalancingPolicy = SplitsBalancingPolicy.STAGE; private int maxUnacknowledgedSplitsPerTask = 500; + private int maxAbsoluteFullNodesPerQuery = Integer.MAX_VALUE; + private double maxFractionFullNodesPerQuery = 0.5; + private NodeAllocatorType nodeAllocatorType = NodeAllocatorType.FULL_NODE_CAPABLE; @NotNull public NodeSchedulerPolicy getNodeSchedulerPolicy() @@ -163,4 +168,60 @@ public NodeSchedulerConfig setOptimizedLocalScheduling(boolean optimizedLocalSch this.optimizedLocalScheduling = optimizedLocalScheduling; return this; } + + @Config("node-scheduler.max-absolute-full-nodes-per-query") + public NodeSchedulerConfig setMaxAbsoluteFullNodesPerQuery(int maxAbsoluteFullNodesPerQuery) + { + this.maxAbsoluteFullNodesPerQuery = maxAbsoluteFullNodesPerQuery; + return this; + } + + public int getMaxAbsoluteFullNodesPerQuery() + { + return maxAbsoluteFullNodesPerQuery; + } + + @Config("node-scheduler.max-fraction-full-nodes-per-query") + public NodeSchedulerConfig 
setMaxFractionFullNodesPerQuery(double maxFractionFullNodesPerQuery) + { + this.maxFractionFullNodesPerQuery = maxFractionFullNodesPerQuery; + return this; + } + + @DecimalMin("0.0") + @DecimalMax("1.0") + public double getMaxFractionFullNodesPerQuery() + { + return maxFractionFullNodesPerQuery; + } + + public enum NodeAllocatorType + { + FIXED_COUNT, + FULL_NODE_CAPABLE + } + + @NotNull + public NodeAllocatorType getNodeAllocatorType() + { + return nodeAllocatorType; + } + + @Config("node-scheduler.allocator-type") + public NodeSchedulerConfig setNodeAllocatorType(String nodeAllocatorType) + { + this.nodeAllocatorType = toNodeAllocatorType(nodeAllocatorType); + return this; + } + + private static NodeAllocatorType toNodeAllocatorType(String nodeAllocatorType) + { + switch (nodeAllocatorType.toLowerCase(ENGLISH)) { + case "fixed_count": + return NodeAllocatorType.FIXED_COUNT; + case "full_node_capable": + return NodeAllocatorType.FULL_NODE_CAPABLE; + } + throw new IllegalArgumentException("Unknown node allocator type: " + nodeAllocatorType); + } } diff --git a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java index 6d2c102bb99d..f8e564ee9a1f 100644 --- a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java +++ b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java @@ -510,8 +510,7 @@ public synchronized Map> getWorkerMemoryInfo() { Map> memoryInfo = new HashMap<>(); for (Entry entry : nodes.entrySet()) { - // workerId is of the form "node_identifier [node_host]" - String workerId = entry.getKey() + " [" + entry.getValue().getNode().getHost() + "]"; + String workerId = entry.getKey(); memoryInfo.put(workerId, entry.getValue().getInfo()); } return memoryInfo; diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 
fe182266fcc2..9784e666f941 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -61,8 +61,11 @@ import io.trino.execution.resourcegroups.LegacyResourceGroupConfigurationManager; import io.trino.execution.resourcegroups.ResourceGroupManager; import io.trino.execution.scheduler.ConstantPartitionMemoryEstimator; +import io.trino.execution.scheduler.FallbackToFullNodePartitionMemoryEstimator; import io.trino.execution.scheduler.FixedCountNodeAllocatorService; +import io.trino.execution.scheduler.FullNodeCapableNodeAllocatorService; import io.trino.execution.scheduler.NodeAllocatorService; +import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.execution.scheduler.PartitionMemoryEstimator; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.StageTaskSourceFactory; @@ -131,6 +134,8 @@ import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.FIXED_COUNT; +import static io.trino.execution.scheduler.NodeSchedulerConfig.NodeAllocatorType.FULL_NODE_CAPABLE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; @@ -214,8 +219,20 @@ protected void setup(Binder binder) newExporter(binder).export(ClusterMemoryManager.class).withGeneratedName(); // node allocator - binder.bind(NodeAllocatorService.class).to(FixedCountNodeAllocatorService.class).in(Scopes.SINGLETON); - binder.bind(PartitionMemoryEstimator.class).to(ConstantPartitionMemoryEstimator.class).in(Scopes.SINGLETON); + install(conditionalModule( + NodeSchedulerConfig.class, + config -> FIXED_COUNT == 
config.getNodeAllocatorType(), + innerBinder -> { + innerBinder.bind(NodeAllocatorService.class).to(FixedCountNodeAllocatorService.class).in(Scopes.SINGLETON); + innerBinder.bind(PartitionMemoryEstimator.class).to(ConstantPartitionMemoryEstimator.class).in(Scopes.SINGLETON); + })); + install(conditionalModule( + NodeSchedulerConfig.class, + config -> FULL_NODE_CAPABLE == config.getNodeAllocatorType(), + innerBinder -> { + innerBinder.bind(NodeAllocatorService.class).to(FullNodeCapableNodeAllocatorService.class).in(Scopes.SINGLETON); + innerBinder.bind(PartitionMemoryEstimator.class).to(FallbackToFullNodePartitionMemoryEstimator.class).in(Scopes.SINGLETON); + })); // node monitor binder.bind(ClusterSizeMonitor.class).in(Scopes.SINGLETON); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestNodeSchedulerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestNodeSchedulerConfig.java index d2699e30e0bb..0ceab303f0b1 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestNodeSchedulerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestNodeSchedulerConfig.java @@ -38,7 +38,10 @@ public void testDefaults() .setMaxUnacknowledgedSplitsPerTask(500) .setIncludeCoordinator(true) .setSplitsBalancingPolicy(NodeSchedulerConfig.SplitsBalancingPolicy.STAGE) - .setOptimizedLocalScheduling(true)); + .setOptimizedLocalScheduling(true) + .setMaxAbsoluteFullNodesPerQuery(Integer.MAX_VALUE) + .setMaxFractionFullNodesPerQuery(0.5) + .setNodeAllocatorType("full_node_capable")); } @Test @@ -53,6 +56,9 @@ public void testExplicitPropertyMappings() .put("node-scheduler.max-unacknowledged-splits-per-task", "501") .put("node-scheduler.splits-balancing-policy", "node") .put("node-scheduler.optimized-local-scheduling", "false") + .put("node-scheduler.max-absolute-full-nodes-per-query", "17") + .put("node-scheduler.max-fraction-full-nodes-per-query", "0.3") + .put("node-scheduler.allocator-type", "fixed_count") .buildOrThrow(); 
NodeSchedulerConfig expected = new NodeSchedulerConfig() @@ -63,7 +69,10 @@ public void testExplicitPropertyMappings() .setMaxUnacknowledgedSplitsPerTask(501) .setMinCandidates(11) .setSplitsBalancingPolicy(NODE) - .setOptimizedLocalScheduling(false); + .setOptimizedLocalScheduling(false) + .setMaxAbsoluteFullNodesPerQuery(17) + .setMaxFractionFullNodesPerQuery(0.3) + .setNodeAllocatorType("fixed_count"); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFullNodeCapableNodeAllocator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFullNodeCapableNodeAllocator.java new file mode 100644 index 000000000000..f0883beaad4c --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFullNodeCapableNodeAllocator.java @@ -0,0 +1,698 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import io.airlift.units.DataSize; +import io.trino.Session; +import io.trino.client.NodeVersion; +import io.trino.connector.CatalogName; +import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; +import io.trino.memory.MemoryInfo; +import io.trino.metadata.InternalNode; +import io.trino.spi.HostAddress; +import io.trino.spi.QueryId; +import io.trino.spi.memory.MemoryPoolInfo; +import io.trino.testing.assertions.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.concurrent.MoreFutures.getFutureValue; +import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static io.trino.execution.scheduler.FallbackToFullNodePartitionMemoryEstimator.FULL_NODE_MEMORY; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertTrue; + +// uses mutable state +@Test(singleThreaded = true) +public class TestFullNodeCapableNodeAllocator +{ + private static final Session Q1_SESSION = testSessionBuilder().setQueryId(QueryId.valueOf("q1")).build(); + private static final Session Q2_SESSION = testSessionBuilder().setQueryId(QueryId.valueOf("q2")).build(); + + private static final HostAddress NODE_1_ADDRESS = HostAddress.fromParts("127.0.0.1", 8080); + private static final HostAddress 
NODE_2_ADDRESS = HostAddress.fromParts("127.0.0.1", 8081); + private static final HostAddress NODE_3_ADDRESS = HostAddress.fromParts("127.0.0.1", 8082); + private static final HostAddress NODE_4_ADDRESS = HostAddress.fromParts("127.0.0.1", 8083); + + private static final InternalNode NODE_1 = new InternalNode("node-1", URI.create("local://" + NODE_1_ADDRESS), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_2 = new InternalNode("node-2", URI.create("local://" + NODE_2_ADDRESS), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_3 = new InternalNode("node-3", URI.create("local://" + NODE_3_ADDRESS), NodeVersion.UNKNOWN, false); + private static final InternalNode NODE_4 = new InternalNode("node-4", URI.create("local://" + NODE_4_ADDRESS), NodeVersion.UNKNOWN, false); + + private static final CatalogName CATALOG_1 = new CatalogName("catalog1"); + private static final CatalogName CATALOG_2 = new CatalogName("catalog2"); + + private static final NodeRequirements NO_REQUIREMENTS = new NodeRequirements(Optional.empty(), Set.of(), DataSize.of(32, GIGABYTE)); + private static final NodeRequirements SHARED_NODE_CATALOG_1_REQUIREMENTS = new NodeRequirements(Optional.of(CATALOG_1), Set.of(), DataSize.of(32, GIGABYTE)); + private static final NodeRequirements FULL_NODE_REQUIREMENTS = new NodeRequirements(Optional.empty(), Set.of(), FULL_NODE_MEMORY); + private static final NodeRequirements FULL_NODE_1_REQUIREMENTS = new NodeRequirements(Optional.empty(), Set.of(NODE_1_ADDRESS), FULL_NODE_MEMORY); + private static final NodeRequirements FULL_NODE_2_REQUIREMENTS = new NodeRequirements(Optional.empty(), Set.of(NODE_2_ADDRESS), FULL_NODE_MEMORY); + private static final NodeRequirements FULL_NODE_3_REQUIREMENTS = new NodeRequirements(Optional.empty(), Set.of(NODE_3_ADDRESS), FULL_NODE_MEMORY); + private static final NodeRequirements FULL_NODE_CATALOG_1_REQUIREMENTS = new NodeRequirements(Optional.of(CATALOG_1), Set.of(), FULL_NODE_MEMORY); + 
private static final NodeRequirements FULL_NODE_CATALOG_2_REQUIREMENTS = new NodeRequirements(Optional.of(CATALOG_2), Set.of(), FULL_NODE_MEMORY); + + // none of the tests should require periodic execution of routine which processes pending acquisitions + private static final long TEST_TIMEOUT = FullNodeCapableNodeAllocatorService.PROCESS_PENDING_ACQUIRES_DELAY_SECONDS * 1000 / 2; + + private FullNodeCapableNodeAllocatorService nodeAllocatorService; + + private void setupNodeAllocatorService(TestingNodeSupplier testingNodeSupplier, int maxFullNodesPerQuery) + { + shutdownNodeAllocatorService(); // just in case + + MemoryInfo memoryInfo = new MemoryInfo(4, new MemoryPoolInfo(DataSize.of(64, GIGABYTE).toBytes(), 0, 0, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of())); + + Map> workerMemoryInfos = ImmutableMap.of( + NODE_1.getNodeIdentifier(), Optional.of(memoryInfo), + NODE_2.getNodeIdentifier(), Optional.of(memoryInfo), + NODE_3.getNodeIdentifier(), Optional.of(memoryInfo), + NODE_4.getNodeIdentifier(), Optional.of(memoryInfo)); + nodeAllocatorService = new FullNodeCapableNodeAllocatorService( + new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, testingNodeSupplier)), + () -> workerMemoryInfos, + maxFullNodesPerQuery, + 1.0); + nodeAllocatorService.start(); + } + + @AfterMethod(alwaysRun = true) + public void shutdownNodeAllocatorService() + { + if (nodeAllocatorService != null) { + nodeAllocatorService.stop(); + } + nodeAllocatorService = null; + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateSharedSimple() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // first two allocation should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire1); + NodeAllocator.NodeLease 
acquire2 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire2); + // and different nodes should be assigned for each + assertThat(Set.of(acquire1.getNode().get().getNode(), acquire2.getNode().get().getNode())).containsExactlyInAnyOrder(NODE_1, NODE_2); + + // same for subsequent two allocations (each task requires 32GB and we have 2 nodes with 64GB each) + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire3); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire4); + assertThat(Set.of(acquire3.getNode().get().getNode(), acquire4.getNode().get().getNode())).containsExactlyInAnyOrder(NODE_1, NODE_2); + + // 5th allocation should block + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertNotAcquired(acquire5); + + // release acquire2, which frees capacity on one of the nodes + acquire2.release(); + assertEventually(() -> { + // we need to wait as pending acquires are processed asynchronously + assertAcquired(acquire5); + assertEquals(acquire5.getNode().get().getNode(), acquire2.getNode().get().getNode()); + }); + + // try to acquire one more node (should block) + NodeAllocator.NodeLease acquire6 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertNotAcquired(acquire6); + + // add new node + nodeSupplier.addNode(NODE_3, ImmutableList.of()); + // TODO: make FullNodeCapableNodeAllocatorService react on new node added automatically + nodeAllocatorService.wakeupProcessPendingAcquires(); + + // new node should be assigned + assertEventually(() -> { + assertAcquired(acquire6); + assertEquals(acquire6.getNode().get().getNode(), NODE_3); + }); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateSharedReleaseBeforeAcquired() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = 
nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // first two allocation should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire2, NODE_1); + + // another two should block + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertNotAcquired(acquire3); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertNotAcquired(acquire4); + + // releasing a blocked one should not unblock anything + acquire3.release(); + assertNotAcquired(acquire4); + + // releasing an acquired one should unblock one which is still blocked + acquire2.release(); + assertEventually(() -> assertAcquired(acquire4, NODE_1)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testNoSharedNodeAvailable() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // request a node with specific catalog (not present) + + assertThatThrownBy(() -> nodeAllocator.acquire(SHARED_NODE_CATALOG_1_REQUIREMENTS.withMemory(DataSize.of(64, GIGABYTE)))) + .hasMessage("No nodes available to run query"); + + // add node with specific catalog + nodeSupplier.addNode(NODE_2, ImmutableList.of(CATALOG_1)); + + // we should be able to acquire the node now + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(SHARED_NODE_CATALOG_1_REQUIREMENTS.withMemory(DataSize.of(64, GIGABYTE))); + assertAcquired(acquire1, NODE_2); + + // acquiring one more should block (only one acquire fits a node as we request 64GB) + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(SHARED_NODE_CATALOG_1_REQUIREMENTS.withMemory(DataSize.of(64, GIGABYTE))); + assertNotAcquired(acquire2); + + // remove node with 
catalog + nodeSupplier.removeNode(NODE_2); + // TODO: make FullNodeCapableNodeAllocatorService react on node removed automatically + nodeAllocatorService.wakeupProcessPendingAcquires(); + + // pending acquire2 should be completed now but with an exception + assertEventually(() -> { + assertFalse(acquire2.getNode().isCancelled()); + assertTrue(acquire2.getNode().isDone()); + assertThatThrownBy(() -> getFutureValue(acquire2.getNode())) + .hasMessage("No nodes available to run query"); + }); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testRemoveAcquiredSharedNode() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + + // remove acquired node + nodeSupplier.removeNode(NODE_1); + + // we should still be able to release lease for removed node + acquire1.release(); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullSimple() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2)); + setupNodeAllocatorService(nodeSupplier, 3); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // allocate 2 full nodes should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire2); + + // trying to allocate third full node should block + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire3); + + // third acquisition should unblock if one of old ones is released + acquire1.release(); + assertEventually(() -> { + 
assertAcquired(acquire3); + assertEquals(acquire3.getNode().get().getNode(), acquire1.getNode().get().getNode()); + }); + + // both nodes are used exclusively so we should not be able to acquire a shared node + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertNotAcquired(acquire4); + + // shared acquisition should unblock if one of full ones is released + acquire2.release(); + assertEventually(() -> { + assertAcquired(acquire4); + assertEquals(acquire4.getNode().get().getNode(), acquire2.getNode().get().getNode()); + }); + + // shared acquisition should block full acquisition + NodeAllocator.NodeLease acquire5 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire5); + + // and when shared acquisition is gone full node should be acquired + acquire4.release(); + assertEventually(() -> { + assertAcquired(acquire5); + assertEquals(acquire5.getNode().get().getNode(), acquire4.getNode().get().getNode()); + }); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullReleaseBeforeAcquired() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // first allocation should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + + // another two should block + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire2); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire3); + + // releasing a blocked one should not unblock anything + acquire2.release(); + assertNotAcquired(acquire3); + + // releasing one acquired one should unblock one which is still blocked + acquire1.release(); + assertEventually(() -> assertAcquired(acquire3, 
NODE_1)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullWithQueryLimit() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2, NODE_3)); + setupNodeAllocatorService(nodeSupplier, 2); + + try (NodeAllocator q1NodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION); + NodeAllocator q2NodeAllocator = nodeAllocatorService.getNodeAllocator(Q2_SESSION)) { + // allocate 2 full nodes for Q1 should not block + NodeAllocator.NodeLease q1Acquire1 = q1NodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(q1Acquire1); + NodeAllocator.NodeLease q1Acquire2 = q1NodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(q1Acquire2); + + // third allocation for Q1 should block even though we have 3 nodes available + NodeAllocator.NodeLease q1Acquire3 = q1NodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(q1Acquire3); + + // we should still be able to acquire full node for another query + NodeAllocator.NodeLease q2Acquire1 = q2NodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(q2Acquire1); + + // when we release one of the nodes for Q1 pending q1Acquire3 should unblock + q1Acquire1.release(); + assertEventually(() -> { + assertAcquired(q1Acquire3); + assertEquals(q1Acquire3.getNode().get().getNode(), q1Acquire1.getNode().get().getNode()); + }); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullOpportunistic() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2)); + setupNodeAllocatorService(nodeSupplier, 2); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // allocate both nodes as shared + NodeAllocator.NodeLease shared1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(shared1); + NodeAllocator.NodeLease shared2 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(shared2); + + // try to allocate 2 
full nodes - will block as both nodes in cluster are used + NodeAllocator.NodeLease full1 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(full1); + NodeAllocator.NodeLease full2 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(full2); + + // add new node to the cluster + nodeSupplier.addNode(NODE_3, ImmutableList.of()); + // TODO: make FullNodeCapableNodeAllocatorService react on new node added automatically + nodeAllocatorService.wakeupProcessPendingAcquires(); + + // one of the full1/full2 should be not blocked now + assertEventually(() -> assertTrue(full1.getNode().isDone() ^ full2.getNode().isDone(), "exactly one of full1/full2 should be unblocked")); + NodeAllocator.NodeLease fullBlocked = full1.getNode().isDone() ? full2 : full1; + NodeAllocator.NodeLease fullNotBlocked = full1.getNode().isDone() ? full1 : full2; + + // and when unblocked one releases node the other should grab it + fullNotBlocked.release(); + nodeAllocatorService.wakeupProcessPendingAcquires(); + assertEventually(() -> assertAcquired(fullBlocked)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullWithAddressRequirements() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2, NODE_3)); + + setupNodeAllocatorService(nodeSupplier, 2); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_1_REQUIREMENTS); + assertAcquired(acquire1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_2_REQUIREMENTS); + assertAcquired(acquire2); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_3_REQUIREMENTS); + assertNotAcquired(acquire3); + + acquire1.release(); + assertEventually(() -> assertAcquired(acquire3)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullWithCatalogRequirements() + throws Exception + { + 
TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(nodesMapBuilder() + .put(NODE_1, ImmutableList.of(CATALOG_1)) + .put(NODE_2, ImmutableList.of(CATALOG_1)) + .put(NODE_3, ImmutableList.of(CATALOG_2)) + .buildOrThrow()); + + setupNodeAllocatorService(nodeSupplier, 2); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // we have 3 nodes available and per-query limit set to 2 but only 1 node that exposes CATALOG_2 + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_CATALOG_2_REQUIREMENTS); + assertAcquired(acquire1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_CATALOG_2_REQUIREMENTS); + assertNotAcquired(acquire2); + + // releasing CATALOG_2 node allows pending lease to acquire it + acquire1.release(); + assertEventually(() -> assertAcquired(acquire2)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullWithQueryLimitAndCatalogRequirements() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(nodesMapBuilder() + .put(NODE_1, ImmutableList.of(CATALOG_1)) + .put(NODE_2, ImmutableList.of(CATALOG_1)) + .put(NODE_3, ImmutableList.of(CATALOG_2)) + .put(NODE_4, ImmutableList.of(CATALOG_2)) + .buildOrThrow()); + + setupNodeAllocatorService(nodeSupplier, 2); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // allocate 2 full nodes for Q1 should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_CATALOG_1_REQUIREMENTS); + assertAcquired(acquire1); + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_CATALOG_2_REQUIREMENTS); + assertAcquired(acquire2); + + // another allocation for CATALOG_1 will block (per query limit is 2) + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_CATALOG_1_REQUIREMENTS); + assertNotAcquired(acquire3); + + // releasing CATALOG_2 node for query will unblock pending lease for CATALOG_1 + 
acquire2.release(); + assertEventually(() -> assertAcquired(acquire3)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullNodeReleaseBeforeAcquiredWaitingOnMaxFullNodesPerQuery() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // first full allocation should not block + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + + // next two should block (maxFullNodesPerQuery == 1) + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire2); + NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire3); + + // releasing a blocked one should not unblock anything + acquire2.release(); + assertNotAcquired(acquire3); + + // releasing an acquired one should unblock one which is still blocked + acquire1.release(); + assertEventually(() -> assertAcquired(acquire3, NODE_1)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testAllocateFullNodeReleaseBeforeAcquiredWaitingOnOtherNodesUsed() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 100); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // allocate NODE_1 in shared mode + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + + // add one more node + nodeSupplier.addNode(NODE_2, ImmutableList.of()); + + // first full allocation should not block + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire2, NODE_2); + + // next two should block (all nodes used) + 
NodeAllocator.NodeLease acquire3 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire3); + NodeAllocator.NodeLease acquire4 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(acquire4); + + // releasing a blocked one should not unblock anything + acquire3.release(); + assertNotAcquired(acquire4); + + // releasing node acquired in shared mode should unblock one which is still blocked + acquire1.release(); + assertEventually(() -> assertAcquired(acquire4, NODE_1)); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testRemoveAcquiredFullNode() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertAcquired(acquire1, NODE_1); + + // remove acquired node + nodeSupplier.removeNode(NODE_1); + + // we should still be able to release lease for removed node + acquire1.release(); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testNoFullNodeAvailable() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1)); + setupNodeAllocatorService(nodeSupplier, 100); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + // request a full node with specific catalog (not present) + + assertThatThrownBy(() -> nodeAllocator.acquire(FULL_NODE_CATALOG_1_REQUIREMENTS)) + .hasMessage("No nodes available to run query"); + + // add node with specific catalog + nodeSupplier.addNode(NODE_2, ImmutableList.of(CATALOG_1)); + + // we should be able to acquire the node now + NodeAllocator.NodeLease acquire1 = nodeAllocator.acquire(FULL_NODE_CATALOG_1_REQUIREMENTS); + assertAcquired(acquire1, NODE_2); + + // acquiring one more should block (all nodes with catalog already 
used) + NodeAllocator.NodeLease acquire2 = nodeAllocator.acquire(FULL_NODE_CATALOG_1_REQUIREMENTS); + assertNotAcquired(acquire2); + + // remove node with catalog + nodeSupplier.removeNode(NODE_2); + // TODO: make FullNodeCapableNodeAllocatorService react on node removed automatically + nodeAllocatorService.wakeupProcessPendingAcquires(); + + // pending acquire2 should be completed now but with an exception + assertEventually(() -> { + assertFalse(acquire2.getNode().isCancelled()); + assertTrue(acquire2.getNode().isDone()); + assertThatThrownBy(() -> getFutureValue(acquire2.getNode())) + .hasMessage("No nodes available to run query"); + }); + } + } + + @Test(timeOut = TEST_TIMEOUT) + public void testRemoveAssignedFullNode() + throws Exception + { + TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(basicNodesMap(NODE_1, NODE_2)); + setupNodeAllocatorService(nodeSupplier, 1); + + try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(Q1_SESSION)) { + NodeAllocator.NodeLease sharedAcquire1 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(sharedAcquire1); + NodeAllocator.NodeLease sharedAcquire2 = nodeAllocator.acquire(NO_REQUIREMENTS); + assertAcquired(sharedAcquire2); + + InternalNode nodeAcquired1 = sharedAcquire1.getNode().get().getNode(); + InternalNode nodeAcquired2 = sharedAcquire2.getNode().get().getNode(); + assertNotEquals(nodeAcquired1, nodeAcquired2); + + // try to acquire full node; should not happen + NodeAllocator.NodeLease fullAcquire = nodeAllocator.acquire(FULL_NODE_REQUIREMENTS); + assertNotAcquired(fullAcquire); + + Set pendingFullNodes = nodeAllocatorService.getPendingFullNodes(); + InternalNode pendingFullNode = Iterables.getOnlyElement(pendingFullNodes); + + // remove assigned node and release shared allocation for it; full node acquire still should not be fulfilled + nodeSupplier.removeNode(pendingFullNode); + sharedAcquire1.release(); + assertNotAcquired(fullAcquire); + + // release remaining node in the 
cluster + sharedAcquire2.release(); + + // full node should be fulfilled now + assertEventually(() -> { + // we need to wait as pending acquires are processed asynchronously + assertAcquired(fullAcquire, nodeAcquired2); + }); + } + } + + private Map> basicNodesMap(InternalNode... nodes) + { + return Arrays.stream(nodes) + .collect(toImmutableMap( + node -> node, + node -> ImmutableList.of())); + } + + private ImmutableMap.Builder> nodesMapBuilder() + { + return ImmutableMap.builder(); + } + + private void assertAcquired(NodeAllocator.NodeLease lease, InternalNode node) + throws Exception + { + assertAcquired(lease, Optional.of(node)); + } + + private void assertAcquired(NodeAllocator.NodeLease lease) + throws Exception + { + assertAcquired(lease, Optional.empty()); + } + + private void assertAcquired(NodeAllocator.NodeLease lease, Optional expectedNode) + throws Exception + { + assertFalse(lease.getNode().isCancelled(), "node lease cancelled"); + assertTrue(lease.getNode().isDone(), "node lease not acquired"); + if (expectedNode.isPresent()) { + assertEquals(lease.getNode().get().getNode(), expectedNode.get()); + } + } + + private void assertNotAcquired(NodeAllocator.NodeLease lease) + { + assertFalse(lease.getNode().isCancelled(), "node lease cancelled"); + assertFalse(lease.getNode().isDone(), "node lease acquired"); + // enforce pending acquires processing and check again + nodeAllocatorService.processPendingAcquires(); + assertFalse(lease.getNode().isCancelled(), "node lease cancelled"); + assertFalse(lease.getNode().isDone(), "node lease acquired"); + } + + private static void assertEventually(ThrowingRunnable assertion) + { + Assert.assertEventually(() -> { + try { + assertion.run(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + interface ThrowingRunnable + { + void run() throws Exception; + } +} From cebfae3ae1eb0857ae04b18de0231cc58a83f882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Mon, 17 Jan 
2022 22:44:13 +0100 Subject: [PATCH 11/11] Separate retries configuration for queries and tasks --- .../io/trino/SystemSessionProperties.java | 34 ++++++++++++--- .../trino/execution/QueryManagerConfig.java | 41 ++++++++++++++++--- .../FaultTolerantStageScheduler.java | 20 ++++++--- .../scheduler/SqlQueryScheduler.java | 23 +++++++---- .../execution/TestQueryManagerConfig.java | 12 ++++-- .../TestFaultTolerantStageScheduler.java | 1 + .../testing/BaseFailureRecoveryTest.java | 3 +- 7 files changed, 105 insertions(+), 29 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index 7b3c98dcf3db..ca86dc07c8a9 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -145,7 +145,9 @@ public final class SystemSessionProperties public static final String INCREMENTAL_HASH_ARRAY_LOAD_FACTOR_ENABLED = "incremental_hash_array_load_factor_enabled"; public static final String MAX_PARTIAL_TOP_N_MEMORY = "max_partial_top_n_memory"; public static final String RETRY_POLICY = "retry_policy"; - public static final String RETRY_ATTEMPTS = "retry_attempts"; + public static final String QUERY_RETRY_ATTEMPTS = "query_retry_attempts"; + public static final String TASK_RETRY_ATTEMPTS_OVERALL = "task_retry_attempts_overall"; + public static final String TASK_RETRY_ATTEMPTS_PER_TASK = "task_retry_attempts_per_task"; public static final String RETRY_INITIAL_DELAY = "retry_initial_delay"; public static final String RETRY_MAX_DELAY = "retry_max_delay"; public static final String HIDE_INACCESSIBLE_COLUMNS = "hide_inaccessible_columns"; @@ -683,9 +685,19 @@ public SystemSessionProperties( queryManagerConfig.getRetryPolicy(), false), integerProperty( - RETRY_ATTEMPTS, - "Maximum number of retry attempts", - queryManagerConfig.getRetryAttempts(), + QUERY_RETRY_ATTEMPTS, + "Maximum number of 
query retry attempts", + queryManagerConfig.getQueryRetryAttempts(), + false), + integerProperty( + TASK_RETRY_ATTEMPTS_OVERALL, + "Maximum number of task retry attempts overall", + queryManagerConfig.getTaskRetryAttemptsOverall(), + false), + integerProperty( + TASK_RETRY_ATTEMPTS_PER_TASK, + "Maximum number of task retry attempts per single task", + queryManagerConfig.getTaskRetryAttemptsPerTask(), false), durationProperty( RETRY_INITIAL_DELAY, @@ -1264,9 +1276,19 @@ public static RetryPolicy getRetryPolicy(Session session) return retryPolicy; } - public static int getRetryAttempts(Session session) + public static int getQueryRetryAttempts(Session session) + { + return session.getSystemProperty(QUERY_RETRY_ATTEMPTS, Integer.class); + } + + public static int getTaskRetryAttemptsOverall(Session session) + { + return session.getSystemProperty(TASK_RETRY_ATTEMPTS_OVERALL, Integer.class); + } + + public static int getTaskRetryAttemptsPerTask(Session session) { - return session.getSystemProperty(RETRY_ATTEMPTS, Integer.class); + return session.getSystemProperty(TASK_RETRY_ATTEMPTS_PER_TASK, Integer.class); } public static Duration getRetryInitialDelay(Session session) diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index 70320021dd43..ab59b8204e08 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -76,7 +76,9 @@ public class QueryManagerConfig private Duration requiredWorkersMaxWait = new Duration(5, TimeUnit.MINUTES); private RetryPolicy retryPolicy = RetryPolicy.NONE; - private int retryAttempts = 4; + private int queryRetryAttempts = 4; + private int taskRetryAttemptsPerTask = 2; + private int taskRetryAttemptsOverall = Integer.MAX_VALUE; private Duration retryInitialDelay = new Duration(10, SECONDS); private Duration retryMaxDelay = new 
Duration(1, MINUTES); @@ -414,15 +416,42 @@ public QueryManagerConfig setRetryPolicy(RetryPolicy retryPolicy) } @Min(0) - public int getRetryAttempts() + public int getQueryRetryAttempts() { - return retryAttempts; + return queryRetryAttempts; } - @Config("retry-attempts") - public QueryManagerConfig setRetryAttempts(int retryAttempts) + @Config("query-retry-attempts") + @LegacyConfig("retry-attempts") + public QueryManagerConfig setQueryRetryAttempts(int queryRetryAttempts) { - this.retryAttempts = retryAttempts; + this.queryRetryAttempts = queryRetryAttempts; + return this; + } + + @Min(0) + public int getTaskRetryAttemptsOverall() + { + return taskRetryAttemptsOverall; + } + + @Config("task-retry-attempts-overall") + public QueryManagerConfig setTaskRetryAttemptsOverall(int taskRetryAttemptsOverall) + { + this.taskRetryAttemptsOverall = taskRetryAttemptsOverall; + return this; + } + + @Min(0) + public int getTaskRetryAttemptsPerTask() + { + return taskRetryAttemptsPerTask; + } + + @Config("task-retry-attempts-per-task") + public QueryManagerConfig setTaskRetryAttemptsPerTask(int taskRetryAttemptsPerTask) + { + this.taskRetryAttemptsPerTask = taskRetryAttemptsPerTask; return this; } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java index 305b472dba52..3d376b3e19cf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java @@ -95,6 +95,7 @@ public class FaultTolerantStageScheduler private final NodeAllocator nodeAllocator; private final TaskDescriptorStorage taskDescriptorStorage; private final PartitionMemoryEstimator partitionMemoryEstimator; + private final int maxRetryAttemptsPerTask; private final TaskLifecycleListener taskLifecycleListener; // empty when the results are 
consumed via a direct exchange @@ -130,7 +131,9 @@ public class FaultTolerantStageScheduler @GuardedBy("this") private final Set finishedPartitions = new HashSet<>(); @GuardedBy("this") - private int remainingRetryAttempts; + private int remainingRetryAttemptsOverall; + @GuardedBy("this") + private final Map remainingAttemptsPerTask = new HashMap<>(); @GuardedBy("this") private Map partitionMemoryRequirements = new HashMap<>(); @@ -153,7 +156,8 @@ public FaultTolerantStageScheduler( Map sourceExchanges, Optional sourceBucketToPartitionMap, Optional sourceBucketNodeMap, - int retryAttempts) + int taskRetryAttemptsOverall, + int taskRetryAttemptsPerTask) { checkArgument(!stage.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(), "grouped execution is expected to be disabled"); @@ -170,8 +174,9 @@ public FaultTolerantStageScheduler( this.sourceExchanges = ImmutableMap.copyOf(requireNonNull(sourceExchanges, "sourceExchanges is null")); this.sourceBucketToPartitionMap = requireNonNull(sourceBucketToPartitionMap, "sourceBucketToPartitionMap is null"); this.sourceBucketNodeMap = requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null"); - checkArgument(retryAttempts >= 0, "retryAttempts must be greater than or equal to 0: %s", retryAttempts); - this.remainingRetryAttempts = retryAttempts; + checkArgument(taskRetryAttemptsOverall >= 0, "taskRetryAttemptsOverall must be greater than or equal to 0: %s", taskRetryAttemptsOverall); + this.remainingRetryAttemptsOverall = taskRetryAttemptsOverall; + this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask; } public StageId getStageId() @@ -531,8 +536,11 @@ private void updateTaskStatus(TaskStatus taskStatus, Optional 0 && (errorCode == null || errorCode.getType() != USER_ERROR)) { - remainingRetryAttempts--; + + int taskRemainingAttempts = remainingAttemptsPerTask.getOrDefault(partitionId, maxRetryAttemptsPerTask); + if (remainingRetryAttemptsOverall > 0 && taskRemainingAttempts > 0 && (errorCode == 
null || errorCode.getType() != USER_ERROR)) { + remainingRetryAttemptsOverall--; + remainingAttemptsPerTask.put(partitionId, taskRemainingAttempts - 1); // update memory limits for next attempt MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index 92005b6b1119..e815d1a184f3 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -123,10 +123,12 @@ import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.SystemSessionProperties.getConcurrentLifespansPerNode; import static io.trino.SystemSessionProperties.getHashPartitionCount; -import static io.trino.SystemSessionProperties.getRetryAttempts; +import static io.trino.SystemSessionProperties.getQueryRetryAttempts; import static io.trino.SystemSessionProperties.getRetryInitialDelay; import static io.trino.SystemSessionProperties.getRetryMaxDelay; import static io.trino.SystemSessionProperties.getRetryPolicy; +import static io.trino.SystemSessionProperties.getTaskRetryAttemptsOverall; +import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; import static io.trino.SystemSessionProperties.getWriterMinSize; import static io.trino.connector.CatalogName.isInternalSystemConnector; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; @@ -193,7 +195,9 @@ public class SqlQueryScheduler private final CoordinatorStagesScheduler coordinatorStagesScheduler; private final RetryPolicy retryPolicy; - private final int maxRetryAttempts; + private final int maxQueryRetryAttempts; + private final int maxTaskRetryAttemptsOverall; + private final int maxTaskRetryAttemptsPerTask; private final AtomicInteger currentAttempt = new 
AtomicInteger(); private final Duration retryInitialDelay; private final Duration retryMaxDelay; @@ -270,7 +274,9 @@ public SqlQueryScheduler( coordinatorTaskManager); retryPolicy = getRetryPolicy(queryStateMachine.getSession()); - maxRetryAttempts = getRetryAttempts(queryStateMachine.getSession()); + maxQueryRetryAttempts = getQueryRetryAttempts(queryStateMachine.getSession()); + maxTaskRetryAttemptsOverall = getTaskRetryAttemptsOverall(queryStateMachine.getSession()); + maxTaskRetryAttemptsPerTask = getTaskRetryAttemptsPerTask(queryStateMachine.getSession()); retryInitialDelay = getRetryInitialDelay(queryStateMachine.getSession()); retryMaxDelay = getRetryMaxDelay(queryStateMachine.getSession()); } @@ -344,7 +350,8 @@ private synchronized Optional createDistributedStage exchangeManager, nodePartitioningManager, coordinatorStagesScheduler.getTaskLifecycleListener(), - maxRetryAttempts, + maxTaskRetryAttemptsOverall, + maxTaskRetryAttemptsPerTask, schedulerExecutor, schedulerStats, nodeAllocatorService, @@ -422,7 +429,7 @@ else if (state == DistributedStagesSchedulerState.CANCELED) { private boolean shouldRetry(ErrorCode errorCode) { - return retryPolicy == RetryPolicy.QUERY && currentAttempt.get() < maxRetryAttempts && isRetryableErrorCode(errorCode); + return retryPolicy == RetryPolicy.QUERY && currentAttempt.get() < maxQueryRetryAttempts && isRetryableErrorCode(errorCode); } private static boolean isRetryableErrorCode(ErrorCode errorCode) @@ -1745,7 +1752,8 @@ public static FaultTolerantDistributedStagesScheduler create( ExchangeManager exchangeManager, NodePartitioningManager nodePartitioningManager, TaskLifecycleListener coordinatorTaskLifecycleListener, - int retryAttempts, + int taskRetryAttemptsOverall, + int taskRetryAttemptsPerTask, ScheduledExecutorService scheduledExecutorService, SplitSchedulerStats schedulerStats, NodeAllocatorService nodeAllocatorService, @@ -1818,7 +1826,8 @@ public static FaultTolerantDistributedStagesScheduler create( 
sourceExchanges.buildOrThrow(), inputBucketToPartition.getBucketToPartitionMap(), inputBucketToPartition.getBucketNodeMap(), - retryAttempts); + taskRetryAttemptsOverall, + taskRetryAttemptsPerTask); schedulers.add(scheduler); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index 017db61356bc..8efcf3572aa3 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -63,7 +63,9 @@ public void testDefaults() .setRequiredWorkers(1) .setRequiredWorkersMaxWait(new Duration(5, MINUTES)) .setRetryPolicy(RetryPolicy.NONE) - .setRetryAttempts(4) + .setQueryRetryAttempts(4) + .setTaskRetryAttemptsOverall(Integer.MAX_VALUE) + .setTaskRetryAttemptsPerTask(2) .setRetryInitialDelay(new Duration(10, SECONDS)) .setRetryMaxDelay(new Duration(1, MINUTES)) .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(1, GIGABYTE)) @@ -101,7 +103,9 @@ public void testExplicitPropertyMappings() .put("query-manager.required-workers", "333") .put("query-manager.required-workers-max-wait", "33m") .put("retry-policy", "QUERY") - .put("retry-attempts", "0") + .put("query-retry-attempts", "0") + .put("task-retry-attempts-overall", "17") + .put("task-retry-attempts-per-task", "9") .put("retry-initial-delay", "1m") .put("retry-max-delay", "1h") .put("fault-tolerant-execution-target-task-input-size", "222MB") @@ -136,7 +140,9 @@ public void testExplicitPropertyMappings() .setRequiredWorkers(333) .setRequiredWorkersMaxWait(new Duration(33, MINUTES)) .setRetryPolicy(RetryPolicy.QUERY) - .setRetryAttempts(0) + .setQueryRetryAttempts(0) + .setTaskRetryAttemptsOverall(17) + .setTaskRetryAttemptsPerTask(9) .setRetryInitialDelay(new Duration(1, MINUTES)) .setRetryMaxDelay(new Duration(1, HOURS)) .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(222, MEGABYTE)) 
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java index ab74f4ef8b70..2a92a88efc62 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java @@ -516,6 +516,7 @@ private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( sourceExchanges, Optional.empty(), Optional.empty(), + retryAttempts, retryAttempts); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java index b74717f2ac99..b8d1ee7128a8 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java @@ -102,7 +102,8 @@ protected final QueryRunner createQueryRunner() .put("exchange.max-error-duration", MAX_ERROR_DURATION.toString()) .put("retry-policy", retryPolicy.toString()) .put("retry-initial-delay", "0s") - .put("retry-attempts", "1") + .put("query-retry-attempts", "1") + .put("task-retry-attempts-overall", "1") .put("failure-injection.request-timeout", new Duration(REQUEST_TIMEOUT.toMillis() * 2, MILLISECONDS).toString()) // making http timeouts shorter so tests which simulate communication timeouts finish in reasonable amount of time .put("exchange.http-client.idle-timeout", REQUEST_TIMEOUT.toString())